diff --git a/3rdparty/glslang/.gitattributes b/3rdparty/glslang/.gitattributes old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/.travis.yml b/3rdparty/glslang/.travis.yml old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/CMakeLists.txt b/3rdparty/glslang/SPIRV/CMakeLists.txt old mode 100755 new mode 100644 index 48ca77953..1997e74c3 --- a/3rdparty/glslang/SPIRV/CMakeLists.txt +++ b/3rdparty/glslang/SPIRV/CMakeLists.txt @@ -5,6 +5,7 @@ set(SOURCES SpvBuilder.cpp SpvPostProcess.cpp doc.cpp + SpvTools.cpp disassemble.cpp) set(SPVREMAP_SOURCES @@ -23,6 +24,7 @@ set(HEADERS SpvBuilder.h spvIR.h doc.h + SpvTools.h disassemble.h) set(SPVREMAP_HEADERS diff --git a/3rdparty/glslang/SPIRV/GLSL.ext.KHR.h b/3rdparty/glslang/SPIRV/GLSL.ext.KHR.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/GLSL.std.450.h b/3rdparty/glslang/SPIRV/GLSL.std.450.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/GlslangToSpv.cpp b/3rdparty/glslang/SPIRV/GlslangToSpv.cpp old mode 100755 new mode 100644 index 8205e740d..cffeb2017 --- a/3rdparty/glslang/SPIRV/GlslangToSpv.cpp +++ b/3rdparty/glslang/SPIRV/GlslangToSpv.cpp @@ -54,15 +54,6 @@ namespace spv { #endif } -#if ENABLE_OPT - #include "spirv-tools/optimizer.hpp" - #include "message.h" -#endif - -#if ENABLE_OPT -using namespace spvtools; -#endif - // Glslang includes #include "../glslang/MachineIndependent/localintermediate.h" #include "../glslang/MachineIndependent/SymbolTable.h" @@ -3220,7 +3211,7 @@ void TGlslangToSpvTraverser::updateMemberOffset(const glslang::TType& structType // adjusting this late means inconsistencies with earlier code, which for reflection is an issue // Until reflection is brought in sync with these adjustments, don't apply to $Global, // which is the most likely to rely on reflection, and least likely to rely implicit layouts - if (glslangIntermediate->usingHlslOFfsets() && + if (glslangIntermediate->usingHlslOffsets() && ! memberType.isArray() && memberType.isVector() && structType.getTypeName().compare("$Global") != 0) { int dummySize; int componentAlignment = glslangIntermediate->getBaseAlignmentScalar(memberType, dummySize); @@ -6971,7 +6962,7 @@ void OutputSpvHex(const std::vector& spirv, const char* baseName, if (out.fail()) printf("ERROR: Failed to open file: %s\n", baseName); out << "\t// " << - glslang::GetSpirvGeneratorVersion() << "." << GLSLANG_MINOR_VERSION << "." << GLSLANG_PATCH_LEVEL << + GetSpirvGeneratorVersion() << "." << GLSLANG_MINOR_VERSION << "." << GLSLANG_PATCH_LEVEL << std::endl; if (varName != nullptr) { out << "\t #pragma once" << std::endl; @@ -6998,13 +6989,13 @@ void OutputSpvHex(const std::vector& spirv, const char* baseName, // // Set up the glslang traversal // -void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector& spirv, SpvOptions* options) +void GlslangToSpv(const TIntermediate& intermediate, std::vector& spirv, SpvOptions* options) { spv::SpvBuildLogger logger; GlslangToSpv(intermediate, spirv, &logger, options); } -void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector& spirv, +void GlslangToSpv(const TIntermediate& intermediate, std::vector& spirv, spv::SpvBuildLogger* logger, SpvOptions* options) { TIntermNode* root = intermediate.getTreeRoot(); @@ -7012,11 +7003,11 @@ void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vectortraverse(&it); @@ -7026,53 +7017,18 @@ void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vectoroptimizeSize) && - !options->disableOptimizer) { - spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + if ((intermediate.getSource() == EShSourceHlsl || options->optimizeSize) && !options->disableOptimizer) + SpirvToolsLegalize(intermediate, spirv, logger, options); - spvtools::Optimizer optimizer(target_env); - optimizer.SetMessageConsumer([](spv_message_level_t level, - const char* source, - const spv_position_t& position, - const char* message) { - std::cerr << StringifyMessage(level, source, position, message) - << std::endl; - }); + if (options->validate) + SpirvToolsValidate(intermediate, spirv, logger); - optimizer.RegisterPass(CreateMergeReturnPass()); - optimizer.RegisterPass(CreateInlineExhaustivePass()); - optimizer.RegisterPass(CreateEliminateDeadFunctionsPass()); - optimizer.RegisterPass(CreateScalarReplacementPass()); - optimizer.RegisterPass(CreateLocalAccessChainConvertPass()); - optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()); - optimizer.RegisterPass(CreateLocalSingleStoreElimPass()); - optimizer.RegisterPass(CreateSimplificationPass()); - optimizer.RegisterPass(CreateAggressiveDCEPass()); - optimizer.RegisterPass(CreateDeadInsertElimPass()); - optimizer.RegisterPass(CreateAggressiveDCEPass()); - optimizer.RegisterPass(CreateDeadBranchElimPass()); - optimizer.RegisterPass(CreateBlockMergePass()); - optimizer.RegisterPass(CreateLocalMultiStoreElimPass()); - optimizer.RegisterPass(CreateIfConversionPass()); - optimizer.RegisterPass(CreateSimplificationPass()); - optimizer.RegisterPass(CreateAggressiveDCEPass()); - optimizer.RegisterPass(CreateDeadInsertElimPass()); - if (options->optimizeSize) { - optimizer.RegisterPass(CreateRedundancyEliminationPass()); - // TODO(greg-lunarg): Add this when AMD driver issues are resolved - // optimizer.RegisterPass(CreateCommonUniformElimPass()); - } - optimizer.RegisterPass(CreateAggressiveDCEPass()); - optimizer.RegisterPass(CreateCFGCleanupPass()); - optimizer.RegisterLegalizationPasses(); + if (options->disassemble) + SpirvToolsDisassemble(std::cout, spirv); - if (!optimizer.Run(spirv.data(), spirv.size(), &spirv)) - return; - } #endif - glslang::GetThreadPoolAllocator().pop(); + GetThreadPoolAllocator().pop(); } }; // end namespace glslang diff --git a/3rdparty/glslang/SPIRV/GlslangToSpv.h b/3rdparty/glslang/SPIRV/GlslangToSpv.h old mode 100644 new mode 100755 index f7f7cff62..4169c12e9 --- a/3rdparty/glslang/SPIRV/GlslangToSpv.h +++ b/3rdparty/glslang/SPIRV/GlslangToSpv.h @@ -38,6 +38,7 @@ #pragma warning(disable : 4464) // relative include path contains '..' #endif +#include "SpvTools.h" #include "../glslang/Include/intermediate.h" #include @@ -47,14 +48,6 @@ namespace glslang { -struct SpvOptions { - SpvOptions() : generateDebugInfo(false), disableOptimizer(true), - optimizeSize(false) { } - bool generateDebugInfo; - bool disableOptimizer; - bool optimizeSize; -}; - void GetSpirvVersion(std::string&); int GetSpirvGeneratorVersion(); void GlslangToSpv(const glslang::TIntermediate& intermediate, std::vector& spirv, diff --git a/3rdparty/glslang/SPIRV/SPVRemapper.cpp b/3rdparty/glslang/SPIRV/SPVRemapper.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/SPVRemapper.h b/3rdparty/glslang/SPIRV/SPVRemapper.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/SpvBuilder.cpp b/3rdparty/glslang/SPIRV/SpvBuilder.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/SpvBuilder.h b/3rdparty/glslang/SPIRV/SpvBuilder.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/SpvPostProcess.cpp b/3rdparty/glslang/SPIRV/SpvPostProcess.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/SpvTools.cpp b/3rdparty/glslang/SPIRV/SpvTools.cpp new file mode 100755 index 000000000..4807b4255 --- /dev/null +++ b/3rdparty/glslang/SPIRV/SpvTools.cpp @@ -0,0 +1,188 @@ +// +// Copyright (C) 2014-2016 LunarG, Inc. +// Copyright (C) 2018 Google, Inc. +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// +// Neither the name of 3Dlabs Inc. Ltd. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +// +// Call into SPIRV-Tools to disassemble, validate, and optimize. +// + +#if ENABLE_OPT + +#include +#include + +#include "SpvTools.h" +#include "spirv-tools/optimizer.hpp" +#include "spirv-tools/libspirv.h" + +namespace glslang { + +// Translate glslang's view of target versioning to what SPIRV-Tools uses. +spv_target_env MapToSpirvToolsEnv(const SpvVersion& spvVersion, spv::SpvBuildLogger* logger) +{ + switch (spvVersion.vulkan) { + case glslang::EShTargetVulkan_1_0: return spv_target_env::SPV_ENV_VULKAN_1_0; + case glslang::EShTargetVulkan_1_1: return spv_target_env::SPV_ENV_VULKAN_1_1; + default: + break; + } + + if (spvVersion.openGl > 0) + return spv_target_env::SPV_ENV_OPENGL_4_5; + + logger->missingFunctionality("Target version for SPIRV-Tools validator"); + return spv_target_env::SPV_ENV_UNIVERSAL_1_0; +} + + +// Use the SPIRV-Tools disassembler to print SPIR-V. +void SpirvToolsDisassemble(std::ostream& out, const std::vector& spirv) +{ + // disassemble + spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_3); + spv_text text; + spv_diagnostic diagnostic = nullptr; + spvBinaryToText(context, spirv.data(), spirv.size(), + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, + &text, &diagnostic); + + // dump + if (diagnostic == nullptr) + out << text->str; + else + spvDiagnosticPrint(diagnostic); + + // teardown + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(context); +} + +// Apply the SPIRV-Tools validator to generated SPIR-V. +void SpirvToolsValidate(const glslang::TIntermediate& intermediate, std::vector& spirv, + spv::SpvBuildLogger* logger) +{ + // validate + spv_context context = spvContextCreate(MapToSpirvToolsEnv(intermediate.getSpv(), logger)); + spv_const_binary_t binary = { spirv.data(), spirv.size() }; + spv_diagnostic diagnostic = nullptr; + spv_validator_options options = spvValidatorOptionsCreate(); + spvValidatorOptionsSetRelaxBlockLayout(options, intermediate.usingHlslOffsets()); + spvValidateWithOptions(context, options, &binary, &diagnostic); + + // report + if (diagnostic != nullptr) { + logger->error("SPIRV-Tools Validation Errors"); + logger->error(diagnostic->error); + } + + // tear down + spvValidatorOptionsDestroy(options); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(context); +} + +// Apply the SPIRV-Tools optimizer to generated SPIR-V, for the purpose of +// legalizing HLSL SPIR-V. +void SpirvToolsLegalize(const glslang::TIntermediate& intermediate, std::vector& spirv, + spv::SpvBuildLogger* logger, const SpvOptions* options) +{ + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + + spvtools::Optimizer optimizer(target_env); + optimizer.SetMessageConsumer( + [](spv_message_level_t level, const char *source, const spv_position_t &position, const char *message) { + auto &out = std::cerr; + switch (level) + { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + out << "error: "; + break; + case SPV_MSG_WARNING: + out << "warning: "; + break; + case SPV_MSG_INFO: + case SPV_MSG_DEBUG: + out << "info: "; + break; + default: + break; + } + if (source) + { + out << source << ":"; + } + out << position.line << ":" << position.column << ":" << position.index << ":"; + if (message) + { + out << " " << message; + } + out << std::endl; + }); + + optimizer.RegisterPass(spvtools::CreateMergeReturnPass()); + optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass()); + optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass()); + optimizer.RegisterPass(spvtools::CreateScalarReplacementPass()); + optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass()); + optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass()); + optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass()); + optimizer.RegisterPass(spvtools::CreateSimplificationPass()); + optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass()); + optimizer.RegisterPass(spvtools::CreateVectorDCEPass()); + optimizer.RegisterPass(spvtools::CreateDeadInsertElimPass()); + optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass()); + optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass()); + optimizer.RegisterPass(spvtools::CreateBlockMergePass()); + optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass()); + optimizer.RegisterPass(spvtools::CreateIfConversionPass()); + optimizer.RegisterPass(spvtools::CreateSimplificationPass()); + optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass()); + optimizer.RegisterPass(spvtools::CreateVectorDCEPass()); + optimizer.RegisterPass(spvtools::CreateDeadInsertElimPass()); + if (options->optimizeSize) { + optimizer.RegisterPass(spvtools::CreateRedundancyEliminationPass()); + // TODO(greg-lunarg): Add this when AMD driver issues are resolved + // optimizer.RegisterPass(CreateCommonUniformElimPass()); + } + optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass()); + optimizer.RegisterPass(spvtools::CreateCFGCleanupPass()); + + optimizer.Run(spirv.data(), spirv.size(), &spirv, spvtools::ValidatorOptions(), true); +} + +}; // end namespace glslang + +#endif diff --git a/3rdparty/glslang/SPIRV/SpvTools.h b/3rdparty/glslang/SPIRV/SpvTools.h new file mode 100755 index 000000000..08bcf3a28 --- /dev/null +++ b/3rdparty/glslang/SPIRV/SpvTools.h @@ -0,0 +1,80 @@ +// +// Copyright (C) 2014-2016 LunarG, Inc. +// Copyright (C) 2018 Google, Inc. +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// +// Neither the name of 3Dlabs Inc. Ltd. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +// +// Call into SPIRV-Tools to disassemble, validate, and optimize. +// + +#pragma once +#ifndef GLSLANG_SPV_TOOLS_H +#define GLSLANG_SPV_TOOLS_H + +#include +#include + +#include "../glslang/MachineIndependent/localintermediate.h" +#include "Logger.h" + +namespace glslang { + +struct SpvOptions { + SpvOptions() : generateDebugInfo(false), disableOptimizer(true), + optimizeSize(false), disassemble(false), validate(false) { } + bool generateDebugInfo; + bool disableOptimizer; + bool optimizeSize; + bool disassemble; + bool validate; +}; + +#if ENABLE_OPT + +// Use the SPIRV-Tools disassembler to print SPIR-V. +void SpirvToolsDisassemble(std::ostream& out, const std::vector& spirv); + +// Apply the SPIRV-Tools validator to generated SPIR-V. +void SpirvToolsValidate(const glslang::TIntermediate& intermediate, std::vector& spirv, + spv::SpvBuildLogger*); + +// Apply the SPIRV-Tools optimizer to generated SPIR-V, for the purpose of +// legalizing HLSL SPIR-V. +void SpirvToolsLegalize(const glslang::TIntermediate& intermediate, std::vector& spirv, + spv::SpvBuildLogger*, const SpvOptions*); + +#endif + +}; // end namespace glslang + +#endif // GLSLANG_SPV_TOOLS_H \ No newline at end of file diff --git a/3rdparty/glslang/SPIRV/disassemble.cpp b/3rdparty/glslang/SPIRV/disassemble.cpp old mode 100755 new mode 100644 index a8efd693f..6f3160913 --- a/3rdparty/glslang/SPIRV/disassemble.cpp +++ b/3rdparty/glslang/SPIRV/disassemble.cpp @@ -46,6 +46,7 @@ #include "disassemble.h" #include "doc.h" +#include "SpvTools.h" namespace spv { extern "C" { @@ -716,25 +717,4 @@ void Disassemble(std::ostream& out, const std::vector& stream) SpirvStream.processInstructions(); } -#if ENABLE_OPT - -#include "spirv-tools/libspirv.h" - -// Use the SPIRV-Tools disassembler to print SPIR-V. -void SpirvToolsDisassemble(std::ostream& out, const std::vector& spirv) -{ - spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_3); - spv_text text; - spv_diagnostic diagnostic = nullptr; - spvBinaryToText(context, &spirv.front(), spirv.size(), - SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, - &text, &diagnostic); - if (diagnostic == nullptr) - out << text->str; - else - spvDiagnosticPrint(diagnostic); -} - -#endif - }; // end namespace spv diff --git a/3rdparty/glslang/SPIRV/disassemble.h b/3rdparty/glslang/SPIRV/disassemble.h old mode 100755 new mode 100644 index bdde5cb4e..2a9a89b53 --- a/3rdparty/glslang/SPIRV/disassemble.h +++ b/3rdparty/glslang/SPIRV/disassemble.h @@ -48,9 +48,6 @@ namespace spv { // disassemble with glslang custom disassembler void Disassemble(std::ostream& out, const std::vector&); - // disassemble with SPIRV-Tools disassembler - void SpirvToolsDisassemble(std::ostream& out, const std::vector& stream); - }; // end namespace spv #endif // disassembler_H diff --git a/3rdparty/glslang/SPIRV/doc.cpp b/3rdparty/glslang/SPIRV/doc.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/spirv.hpp b/3rdparty/glslang/SPIRV/spirv.hpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/SPIRV/spvIR.h b/3rdparty/glslang/SPIRV/spvIR.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/StandAlone/CMakeLists.txt b/3rdparty/glslang/StandAlone/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/StandAlone/StandAlone.cpp b/3rdparty/glslang/StandAlone/StandAlone.cpp old mode 100755 new mode 100644 index e49b26e5a..38ba597d7 --- a/3rdparty/glslang/StandAlone/StandAlone.cpp +++ b/3rdparty/glslang/StandAlone/StandAlone.cpp @@ -103,6 +103,7 @@ enum TOptions { }; bool targetHlslFunctionality1 = false; bool SpvToolsDisassembler = false; +bool SpvToolsValidate = false; // // Return codes from main/exit(). @@ -514,6 +515,8 @@ void ProcessArguments(std::vector>& workItem break; } else if (lowerword == "spirv-dis") { SpvToolsDisassembler = true; + } else if (lowerword == "spirv-val") { + SpvToolsValidate = true; } else if (lowerword == "stdin") { Options |= EOptionStdin; shaderStageName = argv[1]; @@ -978,6 +981,8 @@ void CompileAndLinkShaderUnits(std::vector compUnits) spvOptions.generateDebugInfo = true; spvOptions.disableOptimizer = (Options & EOptionOptimizeDisable) != 0; spvOptions.optimizeSize = (Options & EOptionOptimizeSize) != 0; + spvOptions.disassemble = SpvToolsDisassembler; + spvOptions.validate = SpvToolsValidate; glslang::GlslangToSpv(*program.getIntermediate((EShLanguage)stage), spirv, &logger, &spvOptions); // Dump the spv to a file or stdout, etc., but only if not doing @@ -989,13 +994,6 @@ void CompileAndLinkShaderUnits(std::vector compUnits) } else { glslang::OutputSpvBin(spirv, GetBinaryName((EShLanguage)stage)); } -#if ENABLE_OPT - if (SpvToolsDisassembler) - spv::SpirvToolsDisassemble(std::cout, spirv); -#else - if (SpvToolsDisassembler) - printf("SPIRV-Tools is not enabled; use -H for human readable SPIR-V\n"); -#endif if (!SpvToolsDisassembler && (Options & EOptionHumanReadableSpv)) spv::Disassemble(std::cout, spirv); } @@ -1427,6 +1425,7 @@ void usage() " --shift-cbuffer-binding | --scb synonyms for --shift-UBO-binding\n" " --spirv-dis output standard-form disassembly; works only\n" " when a SPIR-V generation option is also used\n" + " --spirv-val execute the SPIRV-Tools validator\n" " --source-entrypoint the given shader source function is\n" " renamed to be the given in -e\n" " --sep synonym for --source-entrypoint\n" diff --git a/3rdparty/glslang/Test/110scope.vert b/3rdparty/glslang/Test/110scope.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/400.geom b/3rdparty/glslang/Test/400.geom old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/420.comp b/3rdparty/glslang/Test/420.comp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/badMacroArgs.frag b/3rdparty/glslang/Test/badMacroArgs.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseLegalResults/hlsl.flattenSubset.frag.out b/3rdparty/glslang/Test/baseLegalResults/hlsl.flattenSubset.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseLegalResults/hlsl.flattenSubset2.frag.out b/3rdparty/glslang/Test/baseLegalResults/hlsl.flattenSubset2.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseLegalResults/hlsl.partialFlattenLocal.vert.out b/3rdparty/glslang/Test/baseLegalResults/hlsl.partialFlattenLocal.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseLegalResults/hlsl.partialFlattenMixed.vert.out b/3rdparty/glslang/Test/baseLegalResults/hlsl.partialFlattenMixed.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/310runtimeArray.vert.out b/3rdparty/glslang/Test/baseResults/310runtimeArray.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/320.comp.out b/3rdparty/glslang/Test/baseResults/320.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/320.frag.out b/3rdparty/glslang/Test/baseResults/320.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/320.geom.out b/3rdparty/glslang/Test/baseResults/320.geom.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/320.tesc.out b/3rdparty/glslang/Test/baseResults/320.tesc.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/320.tese.out b/3rdparty/glslang/Test/baseResults/320.tese.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/320.vert.out b/3rdparty/glslang/Test/baseResults/320.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/400.vert.out b/3rdparty/glslang/Test/baseResults/400.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/410.vert.out b/3rdparty/glslang/Test/baseResults/410.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/420.comp.out b/3rdparty/glslang/Test/baseResults/420.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/435.vert.out b/3rdparty/glslang/Test/baseResults/435.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/460.frag.out b/3rdparty/glslang/Test/baseResults/460.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/460.vert.out b/3rdparty/glslang/Test/baseResults/460.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/cppBad.vert.out b/3rdparty/glslang/Test/baseResults/cppBad.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/cppBad2.vert.out b/3rdparty/glslang/Test/baseResults/cppBad2.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/cppDeepNest.frag.out b/3rdparty/glslang/Test/baseResults/cppDeepNest.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/cppPassMacroName.frag.out b/3rdparty/glslang/Test/baseResults/cppPassMacroName.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/cppRelaxSkipTokensErrors.vert.out b/3rdparty/glslang/Test/baseResults/cppRelaxSkipTokensErrors.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/glspv.esversion.vert.out b/3rdparty/glslang/Test/baseResults/glspv.esversion.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/glspv.frag.out b/3rdparty/glslang/Test/baseResults/glspv.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/glspv.version.frag.out b/3rdparty/glslang/Test/baseResults/glspv.version.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/glspv.version.vert.out b/3rdparty/glslang/Test/baseResults/glspv.version.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/glspv.vert.out b/3rdparty/glslang/Test/baseResults/glspv.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.PointSize.geom.out b/3rdparty/glslang/Test/baseResults/hlsl.PointSize.geom.out old mode 100755 new mode 100644 index 0d18f1fde..c21008d28 --- a/3rdparty/glslang/Test/baseResults/hlsl.PointSize.geom.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.PointSize.geom.out @@ -69,6 +69,10 @@ output primitive = line_strip 0:? 'ps' ( in 3-element array of uint PointSize) 0:? 'OutputStream.ps' ( out float PointSize) +error: SPIRV-Tools Validation Errors +error: According to the Vulkan spec BuiltIn PointSize variable needs to be a 32-bit float scalar. ID <28> (OpVariable) is not a float scalar. + %29 = OpLoad %_arr_uint_uint_3 %ps_1 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 36 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.PointSize.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.PointSize.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.aliasOpaque.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.aliasOpaque.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.amend.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.amend.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.array.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.array.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.assoc.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.assoc.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.attribute.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.attribute.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.attributeC11.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.attributeC11.frag.out old mode 100755 new mode 100644 index becb50063..afc746699 --- a/3rdparty/glslang/Test/baseResults/hlsl.attributeC11.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.attributeC11.frag.out @@ -93,6 +93,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=7) out 4-component vector of float) 0:? 'input' (layout( location=8) in 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Operand 2 of Decorate requires one of these capabilities: InputAttachment + OpDecorate %attach InputAttachmentIndex 4 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 51 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.attributeGlobalBuffer.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.attributeGlobalBuffer.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.basic.comp.out b/3rdparty/glslang/Test/baseResults/hlsl.basic.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.boolConv.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.boolConv.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.buffer.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.buffer.frag.out old mode 100755 new mode 100644 index 8d2c51473..4528d98b0 --- a/3rdparty/glslang/Test/baseResults/hlsl.buffer.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.buffer.frag.out @@ -145,6 +145,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput.a' (layout( location=0) out 4-component vector of float) 0:? 'input' ( in 4-component vector of float FragCoord) +error: SPIRV-Tools Validation Errors +error: Structure id 50 decorated as BufferBlock for variable in Uniform storage class must follow standard storage buffer layout rules: member 7 at offset 128 overlaps previous member ending at offset 171 + %tbufName = OpTypeStruct %v4float %int %float %float %float %float %float %float %mat3v4float %mat3v4float %mat3v4float %mat3v4float + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 73 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.cast.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.cast.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.charLit.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.charLit.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.conditional.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.conditional.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.constantbuffer.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.constantbuffer.frag.out index 4b5c6b1c8..4185ea95e 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.constantbuffer.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.constantbuffer.frag.out @@ -131,6 +131,10 @@ gl_FragCoord origin is upper left 0:? 'anon@0' (layout( row_major std140) uniform block{layout( row_major std140) uniform int c1}) 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Only a single level of array is allowed for descriptor set variables + %cb3_0 = OpVariable %_ptr_Uniform__arr__arr_cb3_uint_4_uint_2 Uniform + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 66 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.constructArray.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.constructArray.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.constructimat.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.constructimat.frag.out index c36ff6d0e..e88c3d8f3 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.constructimat.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.constructimat.frag.out @@ -543,6 +543,10 @@ gl_FragCoord origin is upper left 0:? Linker Objects 0:? '@entryPointOutput' (layout( location=0) out int) +error: SPIRV-Tools Validation Errors +error: Matrix types can only be parameterized with floating-point types. + %mat4v4int = OpTypeMatrix %v4int 4 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 98 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.coverage.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.coverage.frag.out index 8afc59af4..bea2fc0e5 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.coverage.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.coverage.frag.out @@ -117,6 +117,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput.nCoverageMask' ( out 1-element array of uint SampleMaskIn) 0:? '@entryPointOutput.vColor' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Input variable id <34> is used by entry point 'main' id <4>, but is not listed as an interface + %i_1 = OpVariable %_ptr_Input_PS_INPUT Input + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 52 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.depthGreater.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.depthGreater.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.depthLess.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.depthLess.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.discard.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.discard.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.doLoop.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.doLoop.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.frag.out index 8c8b62bb1..34a635c70 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.frag.out @@ -49,6 +49,10 @@ gl_FragCoord origin is upper left 0:? 'i' ( temp structure{}) 0:? Linker Objects +error: SPIRV-Tools Validation Errors +error: Input variable id <20> is used by entry point 'main' id <4>, but is not listed as an interface + %i_1 = OpVariable %_ptr_Input_ps_in Input + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 27 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.vert.out index b2aaf5ee7..61704586f 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.vert.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.emptystructreturn.vert.out @@ -47,6 +47,10 @@ Shader version: 500 0:? 'i' ( temp structure{}) 0:? Linker Objects +error: SPIRV-Tools Validation Errors +error: Input variable id <20> is used by entry point 'main' id <4>, but is not listed as an interface + %i_1 = OpVariable %_ptr_Input_vs_in Input + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 27 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.entry-in.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.entry-in.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.entry-out.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.entry-out.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.flattenOpaque.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.flattenOpaque.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.flattenOpaqueInit.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.flattenOpaqueInit.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.flattenOpaqueInitMix.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.flattenOpaqueInitMix.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.flattenSubset.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.flattenSubset.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.flattenSubset2.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.flattenSubset2.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.float1.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.float1.frag.out old mode 100755 new mode 100644 index 786212313..31febfd49 --- a/3rdparty/glslang/Test/baseResults/hlsl.float1.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.float1.frag.out @@ -64,6 +64,10 @@ gl_FragCoord origin is upper left 0:? 'f1' ( global 1-component vector of float) 0:? 'scalar' ( global float) +error: SPIRV-Tools Validation Errors +error: Expected int scalar or vector type as Result Type: IMul + %20 = OpIMul %float %18 %19 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 27 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.float4.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.float4.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.forLoop.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.forLoop.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.function.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.function.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offset.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offset.dx10.frag.out index 49fda31a0..33c9af43e 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offset.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offset.dx10.frag.out @@ -1261,6 +1261,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image Operand ConstOffsets to be a const object + %90 = OpImageGather %v4float %76 %78 %int_0 ConstOffsets %89 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 399 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offsetarray.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offsetarray.dx10.frag.out index 9de1a9721..22b02e7f4 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offsetarray.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.gatherRGBA.offsetarray.dx10.frag.out @@ -1253,6 +1253,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image Operand ConstOffsets to be a const object + %90 = OpImageGather %v4float %76 %78 %int_0 ConstOffsets %89 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 389 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.hull.3.tesc.out b/3rdparty/glslang/Test/baseResults/hlsl.hull.3.tesc.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.if.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.if.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.implicitBool.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.implicitBool.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.include.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.include.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.includeNegative.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.includeNegative.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.inf.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.inf.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.init.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.init.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.comp.out b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.comp.out index a5b543cb2..5058f2365 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.comp.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.comp.out @@ -715,6 +715,10 @@ local_size = (1, 1, 1) 0:? 'inU0' (layout( location=3) in 4-component vector of uint) 0:? 'inU1' (layout( location=4) in 4-component vector of uint) +error: SPIRV-Tools Validation Errors +error: Expected operand to be vector bool: All + %64 = OpAll %bool %63 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 265 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.evalfns.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.evalfns.frag.out index e7865627e..4fd1e7b38 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.evalfns.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.evalfns.frag.out @@ -153,6 +153,10 @@ gl_FragCoord origin is upper left 0:? 'inF4' (layout( location=3) in 4-component vector of float) 0:? 'inI2' (layout( location=4) flat in 2-component vector of int) +error: SPIRV-Tools Validation Errors +error: GLSL.std.450 InterpolateAtOffset: expected Interpolant storage class to be Input + %28 = OpExtInst %float %1 InterpolateAtOffset %inF1 %27 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 80 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.frag.out index 20eb0321b..20d2bb047 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.frag.out @@ -5643,6 +5643,10 @@ gl_FragCoord origin is upper left 0:? 'gs_uc4' ( shared 4-component vector of uint) 0:? '@entryPointOutput.color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Matrix types can only be parameterized with floating-point types. + %mat2v2bool = OpTypeMatrix %v2bool 2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 1836 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.vert.out index 8e7e3ec61..195e11d6e 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.vert.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.intrinsics.vert.out @@ -2778,6 +2778,10 @@ Shader version: 500 0:413 'inFM3x2' ( in 3X2 matrix of float) 0:? Linker Objects +error: SPIRV-Tools Validation Errors +error: Matrix types can only be parameterized with floating-point types. + %mat2v2bool = OpTypeMatrix %v2bool 2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 1225 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.layout.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.layout.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.layoutOverride.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.layoutOverride.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.logicalConvert.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.logicalConvert.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.matNx1.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.matNx1.frag.out index 109362e45..276d4c249 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.matNx1.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.matNx1.frag.out @@ -151,6 +151,10 @@ gl_FragCoord origin is upper left 0:? Linker Objects 0:? '@entryPointOutput.color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Illegal number of components (1) for TypeVector + %v1float = OpTypeVector %float 1 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 77 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.matType.bool.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.matType.bool.frag.out index 82575b04a..900c60fcd 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.matType.bool.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.matType.bool.frag.out @@ -231,6 +231,10 @@ gl_FragCoord origin is upper left 0:? Linker Objects 0:? '@entryPointOutput.color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Illegal number of components (1) for TypeVector + %v1bool = OpTypeVector %bool 1 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 130 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.matType.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.matType.frag.out old mode 100755 new mode 100644 index 958b37e0d..c0d2e4b3f --- a/3rdparty/glslang/Test/baseResults/hlsl.matType.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.matType.frag.out @@ -30,6 +30,10 @@ gl_FragCoord origin is upper left 0:? Linker Objects 0:? 'anon@0' (layout( row_major std140) uniform block{ uniform 1-component vector of float f1, uniform 1X1 matrix of float fmat11, uniform 4X1 matrix of float fmat41, uniform 1X2 matrix of float fmat12, uniform 2X3 matrix of double dmat23, uniform 4X4 matrix of int int44}) +error: SPIRV-Tools Validation Errors +error: Illegal number of components (1) for TypeVector + %v1float = OpTypeVector %float 1 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 30 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.matType.int.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.matType.int.frag.out index b8d29ac05..2039dfd57 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.matType.int.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.matType.int.frag.out @@ -397,6 +397,10 @@ gl_FragCoord origin is upper left 0:? Linker Objects 0:? '@entryPointOutput.color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Illegal number of components (1) for TypeVector + %v1int = OpTypeVector %int 1 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 232 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.matrixSwizzle.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.matrixSwizzle.vert.out old mode 100755 new mode 100644 index 9bf7e5604..abb3e495a --- a/3rdparty/glslang/Test/baseResults/hlsl.matrixSwizzle.vert.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.matrixSwizzle.vert.out @@ -676,6 +676,10 @@ Shader version: 500 0:? 'inf' (layout( location=0) in float) Missing functionality: matrix swizzle +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '42[f3]'s type does not match Object '34's type. + OpStore %f3 %int_0 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 118 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.max.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.max.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.memberFunCall.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.memberFunCall.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.multiEntry.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.multiEntry.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.multiReturn.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.multiReturn.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.namespace.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.namespace.frag.out old mode 100755 new mode 100644 index bfb82dadc..08d959b3e --- a/3rdparty/glslang/Test/baseResults/hlsl.namespace.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.namespace.frag.out @@ -101,6 +101,10 @@ gl_FragCoord origin is upper left 0:? 'N2::gf' ( global float) 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: OpFunctionCall Function 's parameter count does not match the argument count. + %43 = OpFunctionCall %v4float %N2__N3__C1__getVec_ + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 54 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.nonstaticMemberFunction.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.nonstaticMemberFunction.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.overload.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.overload.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.partialFlattenLocal.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.partialFlattenLocal.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.partialFlattenMixed.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.partialFlattenMixed.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.partialInit.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.partialInit.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.pp.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.pp.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.precedence.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.precedence.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.precedence2.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.precedence2.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offset.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offset.dx10.frag.out index 73d69dcb3..ae492e4d4 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offset.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offset.dx10.frag.out @@ -399,6 +399,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image Operand Bias to be float scalar + %28 = OpImageSampleImplicitLod %v4float %23 %float_0_100000001 Bias|ConstOffset %int_1 %float_0_5 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 161 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offsetarray.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offsetarray.dx10.frag.out index 0a7a66b2e..0cea76706 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offsetarray.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplebias.offsetarray.dx10.frag.out @@ -297,6 +297,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image Operand Bias to be float scalar + %31 = OpImageSampleImplicitLod %v4float %23 %27 Bias|ConstOffset %int_0 %float_0_5 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 118 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.array.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.array.dx10.frag.out index f8f20ca25..a41481da5 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.array.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.array.dx10.frag.out @@ -397,6 +397,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %48 = OpImageSampleDrefImplicitLod %float %43 %46 %47 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 209 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.basic.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.basic.dx10.frag.out index 9862297fb..e8252d381 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.basic.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.basic.dx10.frag.out @@ -379,6 +379,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %41 = OpImageSampleDrefImplicitLod %float %38 %39 %40 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 198 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offset.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offset.dx10.frag.out index f0ba44492..cb4ce3909 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offset.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offset.dx10.frag.out @@ -325,6 +325,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %42 = OpImageSampleDrefImplicitLod %float %39 %40 %41 ConstOffset %int_2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 167 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offsetarray.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offsetarray.dx10.frag.out index ae6078cac..af2af3f29 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offsetarray.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmp.offsetarray.dx10.frag.out @@ -337,6 +337,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %49 = OpImageSampleDrefImplicitLod %float %44 %47 %48 ConstOffset %int_2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 178 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.array.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.array.dx10.frag.out index ae5b11896..a0e5a487f 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.array.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.array.dx10.frag.out @@ -433,6 +433,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %49 = OpImageSampleDrefExplicitLod %float %44 %47 %48 Lod %float_0 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 210 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.basic.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.basic.dx10.frag.out index 53ecbf2e0..ffe22988e 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.basic.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.basic.dx10.frag.out @@ -415,6 +415,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %42 = OpImageSampleDrefExplicitLod %float %39 %40 %41 Lod %float_0 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 199 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offset.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offset.dx10.frag.out index 1d4f2cdec..08e8749f9 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offset.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offset.dx10.frag.out @@ -349,6 +349,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %43 = OpImageSampleDrefExplicitLod %float %40 %41 %42 Lod|ConstOffset %float_0 %int_2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 168 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offsetarray.dx10.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offsetarray.dx10.frag.out index dea666337..b5c0fc0d7 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offsetarray.dx10.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.samplecmplevelzero.offsetarray.dx10.frag.out @@ -361,6 +361,10 @@ using depth_any 0:? '@entryPointOutput.Depth' ( out float FragDepth) 0:? '@entryPointOutput.Color' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: Expected Image 'Sampled Type' to be the same as Result Type + %50 = OpImageSampleDrefExplicitLod %float %45 %48 %49 Lod|ConstOffset %float_0 %int_2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 179 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.scalarCast.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.scalarCast.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.scope.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.scope.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.semantic.geom.out b/3rdparty/glslang/Test/baseResults/hlsl.semantic.geom.out old mode 100755 new mode 100644 index 1c1a9c0af..e73940bce --- a/3rdparty/glslang/Test/baseResults/hlsl.semantic.geom.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.semantic.geom.out @@ -155,6 +155,10 @@ output primitive = line_strip 0:? 'OutputStream.clip0' ( out 1-element array of float ClipDistance) 0:? 'OutputStream.cull0' ( out 1-element array of float CullDistance) +error: SPIRV-Tools Validation Errors +error: According to the Vulkan spec BuiltIn Position variable needs to be a 4-component 32-bit float vector. ID <20> (OpVariable) is not a float vector. + OpStore %OutputStream_clip0 %25 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 65 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.semantic.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.semantic.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.shapeConv.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.shapeConv.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.shapeConvRet.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.shapeConvRet.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.sin.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.sin.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.staticFuncInit.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.staticFuncInit.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.staticMemberFunction.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.staticMemberFunction.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.string.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.string.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.struct.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.struct.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.struct.split.assign.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.struct.split.assign.frag.out index 3454eb6d5..0598ac9c4 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.struct.split.assign.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.struct.split.assign.frag.out @@ -207,6 +207,10 @@ gl_FragCoord origin is upper left 0:? 'input[1].f' (layout( location=2) in float) 0:? 'input[2].f' (layout( location=3) in float) +error: SPIRV-Tools Validation Errors +error: According to the Vulkan spec BuiltIn FragCoord variable needs to be a 4-component 32-bit float vector. ID <41> (OpVariable) is not a float vector. + %input_pos = OpVariable %_ptr_Input__arr_v4float_uint_3 Input + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 66 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structIoFourWay.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structIoFourWay.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structStructName.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structStructName.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.append.fn.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.append.fn.frag.out index 9beadc722..df0865565 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.append.fn.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.append.fn.frag.out @@ -149,6 +149,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: Structure id 12 decorated as BufferBlock must be explicitly laid out with Offset decorations. + %__0 = OpTypeStruct %uint + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 70 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.atomics.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.atomics.frag.out index d78f77ed3..68d93c187 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.atomics.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.atomics.frag.out @@ -473,6 +473,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: AtomicIAdd: expected Value to be of type Result Type + %28 = OpAtomicIAdd %uint %24 %uint_1 %uint_0 %int_1 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 87 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.byte.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.byte.frag.out index 862ebbef0..49958a658 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.byte.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.byte.frag.out @@ -323,6 +323,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '14[size]'s type does not match Object '20's type. + OpStore %size %20 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 114 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.coherent.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.coherent.frag.out index 18de2a8d5..1d11b6470 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.coherent.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.coherent.frag.out @@ -175,6 +175,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '26[size]'s type does not match Object '33's type. + OpStore %size %33 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 78 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.fn.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.fn.frag.out index 4bbc550d8..4b8ee635e 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.fn.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.fn.frag.out @@ -137,6 +137,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: Structure id 20 decorated as BufferBlock must be explicitly laid out with Offset decorations. + %__1 = OpTypeStruct %uint + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 78 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.frag.out index 5f6e8eec5..e058d1124 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.frag.out @@ -187,6 +187,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '43[size]'s type does not match Object '44's type. + OpStore %size %44 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 96 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rw.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rw.frag.out index ccf295bdd..7fbd1502d 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rw.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rw.frag.out @@ -175,6 +175,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '26[size]'s type does not match Object '33's type. + OpStore %size %33 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 78 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rwbyte.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rwbyte.frag.out index 9f1b5b325..ed27c8939 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rwbyte.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.structbuffer.rwbyte.frag.out @@ -1003,6 +1003,10 @@ gl_FragCoord origin is upper left 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) 0:? 'pos' (layout( location=0) flat in uint) +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '14[size]'s type does not match Object '20's type. + OpStore %size %20 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 239 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.structin.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.structin.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.switch.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.switch.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.swizzle.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.swizzle.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.synthesizeInput.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.synthesizeInput.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.target.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.target.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.targetStruct1.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.targetStruct1.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.targetStruct2.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.targetStruct2.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.texture.struct.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.texture.struct.frag.out index 0778f5061..62cb57497 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.texture.struct.frag.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.texture.struct.frag.out @@ -837,6 +837,10 @@ gl_FragCoord origin is upper left 0:? 'g_tTex2s1a' ( uniform texture2D) 0:? '@entryPointOutput' (layout( location=0) out 4-component vector of float) +error: SPIRV-Tools Validation Errors +error: OpStore Pointer '185's type does not match Object '184's type. + OpStore %185 %184 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 240 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.this.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.this.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.tristream-append.geom.out b/3rdparty/glslang/Test/baseResults/hlsl.tristream-append.geom.out index be6ca9cd6..c11672495 100644 --- a/3rdparty/glslang/Test/baseResults/hlsl.tristream-append.geom.out +++ b/3rdparty/glslang/Test/baseResults/hlsl.tristream-append.geom.out @@ -105,6 +105,10 @@ output primitive = triangle_strip 0:? 'TriStream' ( temp structure{}) 0:? Linker Objects +error: SPIRV-Tools Validation Errors +error: Output variable id <23> is used by entry point 'main' id <4>, but is not listed as an interface + %TriStream_1 = OpVariable %_ptr_Output_GSPS_INPUT Output + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 57 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.typeGraphCopy.vert.out b/3rdparty/glslang/Test/baseResults/hlsl.typeGraphCopy.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.typedef.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.typedef.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.void.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.void.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/hlsl.whileLoop.frag.out b/3rdparty/glslang/Test/baseResults/hlsl.whileLoop.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/implicitInnerAtomicUint.frag.out b/3rdparty/glslang/Test/baseResults/implicitInnerAtomicUint.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/missingBodies.vert.out b/3rdparty/glslang/Test/baseResults/missingBodies.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/mixedArrayDecls.frag.out b/3rdparty/glslang/Test/baseResults/mixedArrayDecls.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/nonuniform.frag.out b/3rdparty/glslang/Test/baseResults/nonuniform.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/runtimeArray.vert.out b/3rdparty/glslang/Test/baseResults/runtimeArray.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/specExamplesConf.vert.out b/3rdparty/glslang/Test/baseResults/specExamplesConf.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.100ops.frag.out b/3rdparty/glslang/Test/baseResults/spv.100ops.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.130.frag.out b/3rdparty/glslang/Test/baseResults/spv.130.frag.out index d1a626d8d..eb02bade5 100644 --- a/3rdparty/glslang/Test/baseResults/spv.130.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.130.frag.out @@ -1,6 +1,10 @@ spv.130.frag WARNING: 0:31: '#extension' : extension is only partially supported: GL_ARB_gpu_shader5 +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 205 diff --git a/3rdparty/glslang/Test/baseResults/spv.140.frag.out b/3rdparty/glslang/Test/baseResults/spv.140.frag.out old mode 100755 new mode 100644 index 89bf4899b..8a59c2f9b --- a/3rdparty/glslang/Test/baseResults/spv.140.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.140.frag.out @@ -1,4 +1,8 @@ spv.140.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 96 diff --git a/3rdparty/glslang/Test/baseResults/spv.150.geom.out b/3rdparty/glslang/Test/baseResults/spv.150.geom.out old mode 100755 new mode 100644 index f75979383..70dadf5de --- a/3rdparty/glslang/Test/baseResults/spv.150.geom.out +++ b/3rdparty/glslang/Test/baseResults/spv.150.geom.out @@ -1,4 +1,8 @@ spv.150.geom +error: SPIRV-Tools Validation Errors +error: Capability GeometryStreams is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability GeometryStreams + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 71 diff --git a/3rdparty/glslang/Test/baseResults/spv.150.vert.out b/3rdparty/glslang/Test/baseResults/spv.150.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.16bitstorage-int.frag.out b/3rdparty/glslang/Test/baseResults/spv.16bitstorage-int.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.16bitstorage-uint.frag.out b/3rdparty/glslang/Test/baseResults/spv.16bitstorage-uint.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.16bitstorage.frag.out b/3rdparty/glslang/Test/baseResults/spv.16bitstorage.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.16bitstorage_Error-int.frag.out b/3rdparty/glslang/Test/baseResults/spv.16bitstorage_Error-int.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.16bitstorage_Error-uint.frag.out b/3rdparty/glslang/Test/baseResults/spv.16bitstorage_Error-uint.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.16bitstorage_Error.frag.out b/3rdparty/glslang/Test/baseResults/spv.16bitstorage_Error.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.300BuiltIns.vert.out b/3rdparty/glslang/Test/baseResults/spv.300BuiltIns.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.300layout.frag.out b/3rdparty/glslang/Test/baseResults/spv.300layout.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.300layoutp.vert.out b/3rdparty/glslang/Test/baseResults/spv.300layoutp.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.310.bitcast.frag.out b/3rdparty/glslang/Test/baseResults/spv.310.bitcast.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.400.frag.out b/3rdparty/glslang/Test/baseResults/spv.400.frag.out index 9cb2c63a1..a0583cff8 100644 --- a/3rdparty/glslang/Test/baseResults/spv.400.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.400.frag.out @@ -1,4 +1,8 @@ spv.400.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 1118 diff --git a/3rdparty/glslang/Test/baseResults/spv.400.tese.out b/3rdparty/glslang/Test/baseResults/spv.400.tese.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.420.geom.out b/3rdparty/glslang/Test/baseResults/spv.420.geom.out index 74a4f0b31..45f235f4b 100644 --- a/3rdparty/glslang/Test/baseResults/spv.420.geom.out +++ b/3rdparty/glslang/Test/baseResults/spv.420.geom.out @@ -1,4 +1,8 @@ spv.420.geom +error: SPIRV-Tools Validation Errors +error: Capability GeometryStreams is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability GeometryStreams + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 72 diff --git a/3rdparty/glslang/Test/baseResults/spv.430.frag.out b/3rdparty/glslang/Test/baseResults/spv.430.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.430.vert.out b/3rdparty/glslang/Test/baseResults/spv.430.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.450.geom.out b/3rdparty/glslang/Test/baseResults/spv.450.geom.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.450.noRedecl.tesc.out b/3rdparty/glslang/Test/baseResults/spv.450.noRedecl.tesc.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.450.tesc.out b/3rdparty/glslang/Test/baseResults/spv.450.tesc.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.460.comp.out b/3rdparty/glslang/Test/baseResults/spv.460.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.460.frag.out b/3rdparty/glslang/Test/baseResults/spv.460.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.460.vert.out b/3rdparty/glslang/Test/baseResults/spv.460.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.8bitstorage-int.frag.out b/3rdparty/glslang/Test/baseResults/spv.8bitstorage-int.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.8bitstorage-uint.frag.out b/3rdparty/glslang/Test/baseResults/spv.8bitstorage-uint.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.8bitstorage_Error-int.frag.out b/3rdparty/glslang/Test/baseResults/spv.8bitstorage_Error-int.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.8bitstorage_Error-uint.frag.out b/3rdparty/glslang/Test/baseResults/spv.8bitstorage_Error-uint.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.AofA.frag.out b/3rdparty/glslang/Test/baseResults/spv.AofA.frag.out index a19fae921..798f083af 100644 --- a/3rdparty/glslang/Test/baseResults/spv.AofA.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.AofA.frag.out @@ -1,6 +1,10 @@ spv.AofA.frag WARNING: 0:6: '[][]' : Generating SPIR-V array-of-arrays, but Vulkan only supports single array level for this resource +error: SPIRV-Tools Validation Errors +error: Only a single level of array is allowed for descriptor set variables + %nameAofA = OpVariable %_ptr_Uniform__arr__arr_uAofA_uint_5_uint_3 Uniform + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 104 diff --git a/3rdparty/glslang/Test/baseResults/spv.Operations.frag.out b/3rdparty/glslang/Test/baseResults/spv.Operations.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.accessChain.frag.out b/3rdparty/glslang/Test/baseResults/spv.accessChain.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.always-discard2.frag.out b/3rdparty/glslang/Test/baseResults/spv.always-discard2.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.atomic.comp.out b/3rdparty/glslang/Test/baseResults/spv.atomic.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.atomicInt64.comp.out b/3rdparty/glslang/Test/baseResults/spv.atomicInt64.comp.out index 9c66aecc8..a273c6694 100644 --- a/3rdparty/glslang/Test/baseResults/spv.atomicInt64.comp.out +++ b/3rdparty/glslang/Test/baseResults/spv.atomicInt64.comp.out @@ -1,4 +1,8 @@ spv.atomicInt64.comp +error: SPIRV-Tools Validation Errors +error: Capability Int64Atomics is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability Int64Atomics + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 149 diff --git a/3rdparty/glslang/Test/baseResults/spv.barrier.vert.out b/3rdparty/glslang/Test/baseResults/spv.barrier.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.builtInXFB.vert.out b/3rdparty/glslang/Test/baseResults/spv.builtInXFB.vert.out old mode 100755 new mode 100644 index 556a698ca..f175a19fa --- a/3rdparty/glslang/Test/baseResults/spv.builtInXFB.vert.out +++ b/3rdparty/glslang/Test/baseResults/spv.builtInXFB.vert.out @@ -1,4 +1,8 @@ spv.builtInXFB.vert +error: SPIRV-Tools Validation Errors +error: Capability TransformFeedback is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability TransformFeedback + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 21 diff --git a/3rdparty/glslang/Test/baseResults/spv.conditionalDiscard.frag.out b/3rdparty/glslang/Test/baseResults/spv.conditionalDiscard.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.constStruct.vert.out b/3rdparty/glslang/Test/baseResults/spv.constStruct.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.controlFlowAttributes.frag.out b/3rdparty/glslang/Test/baseResults/spv.controlFlowAttributes.frag.out old mode 100755 new mode 100644 index eb253822e..2f074def2 --- a/3rdparty/glslang/Test/baseResults/spv.controlFlowAttributes.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.controlFlowAttributes.frag.out @@ -7,6 +7,8 @@ WARNING: 0:24: '' : attribute with arguments not recognized, skipping WARNING: 0:25: '' : attribute with arguments not recognized, skipping WARNING: 0:26: '' : attribute with arguments not recognized, skipping +error: SPIRV-Tools Validation Errors +error: Invalid loop control operand: 4 has invalid mask component 4 // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 118 diff --git a/3rdparty/glslang/Test/baseResults/spv.conversion.frag.out b/3rdparty/glslang/Test/baseResults/spv.conversion.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.dataOut.frag.out b/3rdparty/glslang/Test/baseResults/spv.dataOut.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.dataOutIndirect.frag.out b/3rdparty/glslang/Test/baseResults/spv.dataOutIndirect.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.dataOutIndirect.vert.out b/3rdparty/glslang/Test/baseResults/spv.dataOutIndirect.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.debugInfo.1.1.frag.out b/3rdparty/glslang/Test/baseResults/spv.debugInfo.1.1.frag.out index facaf9e9c..7ba005225 100644 --- a/3rdparty/glslang/Test/baseResults/spv.debugInfo.1.1.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.debugInfo.1.1.frag.out @@ -1,4 +1,6 @@ spv.debugInfo.frag +error: SPIRV-Tools Validation Errors +error: Invalid SPIR-V binary version 1.3 for target environment SPIR-V 1.0 (under OpenGL 4.5 semantics). // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 124 diff --git a/3rdparty/glslang/Test/baseResults/spv.depthOut.frag.out b/3rdparty/glslang/Test/baseResults/spv.depthOut.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.discard-dce.frag.out b/3rdparty/glslang/Test/baseResults/spv.discard-dce.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.do-simple.vert.out b/3rdparty/glslang/Test/baseResults/spv.do-simple.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.doWhileLoop.frag.out b/3rdparty/glslang/Test/baseResults/spv.doWhileLoop.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.double.comp.out b/3rdparty/glslang/Test/baseResults/spv.double.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.earlyReturnDiscard.frag.out b/3rdparty/glslang/Test/baseResults/spv.earlyReturnDiscard.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.explicittypes.frag.out b/3rdparty/glslang/Test/baseResults/spv.explicittypes.frag.out old mode 100755 new mode 100644 index 6f7f2b9ef..44f5ddd3b --- a/3rdparty/glslang/Test/baseResults/spv.explicittypes.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.explicittypes.frag.out @@ -1,4 +1,8 @@ spv.explicittypes.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.1 specification (or requires extension) + OpCapability Float16 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 576 diff --git a/3rdparty/glslang/Test/baseResults/spv.float16.frag.out b/3rdparty/glslang/Test/baseResults/spv.float16.frag.out index b6d37f455..b9ca3397b 100644 --- a/3rdparty/glslang/Test/baseResults/spv.float16.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.float16.frag.out @@ -1,4 +1,8 @@ spv.float16.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability Float16 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 534 diff --git a/3rdparty/glslang/Test/baseResults/spv.float16Fetch.frag.out b/3rdparty/glslang/Test/baseResults/spv.float16Fetch.frag.out index 67ddd61b9..857ca6fac 100644 --- a/3rdparty/glslang/Test/baseResults/spv.float16Fetch.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.float16Fetch.frag.out @@ -1,4 +1,8 @@ spv.float16Fetch.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability Float16 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 5923 diff --git a/3rdparty/glslang/Test/baseResults/spv.float32.frag.out b/3rdparty/glslang/Test/baseResults/spv.float32.frag.out index 40c6677f7..9ee7d7f93 100644 --- a/3rdparty/glslang/Test/baseResults/spv.float32.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.float32.frag.out @@ -1,4 +1,8 @@ spv.float32.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.1 specification (or requires extension) + OpCapability Float16 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 533 diff --git a/3rdparty/glslang/Test/baseResults/spv.float64.frag.out b/3rdparty/glslang/Test/baseResults/spv.float64.frag.out index 491395488..3f095c1f7 100644 --- a/3rdparty/glslang/Test/baseResults/spv.float64.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.float64.frag.out @@ -1,4 +1,8 @@ spv.float64.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.1 specification (or requires extension) + OpCapability Float16 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 524 diff --git a/3rdparty/glslang/Test/baseResults/spv.flowControl.frag.out b/3rdparty/glslang/Test/baseResults/spv.flowControl.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.for-simple.vert.out b/3rdparty/glslang/Test/baseResults/spv.for-simple.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.forLoop.frag.out b/3rdparty/glslang/Test/baseResults/spv.forLoop.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.forwardFun.frag.out b/3rdparty/glslang/Test/baseResults/spv.forwardFun.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.functionCall.frag.out b/3rdparty/glslang/Test/baseResults/spv.functionCall.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.functionSemantics.frag.out b/3rdparty/glslang/Test/baseResults/spv.functionSemantics.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.glFragColor.frag.out b/3rdparty/glslang/Test/baseResults/spv.glFragColor.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.image.frag.out b/3rdparty/glslang/Test/baseResults/spv.image.frag.out index 2c35a0c2f..2f925b3f6 100644 --- a/3rdparty/glslang/Test/baseResults/spv.image.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.image.frag.out @@ -1,4 +1,8 @@ spv.image.frag +error: SPIRV-Tools Validation Errors +error: Capability ImageRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability ImageRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 376 diff --git a/3rdparty/glslang/Test/baseResults/spv.imageLoadStoreLod.frag.out b/3rdparty/glslang/Test/baseResults/spv.imageLoadStoreLod.frag.out index 4c65a3674..db9177d08 100644 --- a/3rdparty/glslang/Test/baseResults/spv.imageLoadStoreLod.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.imageLoadStoreLod.frag.out @@ -1,4 +1,8 @@ spv.imageLoadStoreLod.frag +error: SPIRV-Tools Validation Errors +error: Image Operand Lod can only be used with ExplicitLod opcodes and OpImageFetch + %19 = OpImageRead %v4float %15 %int_1 Lod %int_3 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 82 diff --git a/3rdparty/glslang/Test/baseResults/spv.int16.amd.frag.out b/3rdparty/glslang/Test/baseResults/spv.int16.amd.frag.out index c404375bf..ab0861479 100644 --- a/3rdparty/glslang/Test/baseResults/spv.int16.amd.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.int16.amd.frag.out @@ -1,4 +1,8 @@ spv.int16.amd.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability Float16 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 560 diff --git a/3rdparty/glslang/Test/baseResults/spv.int16.frag.out b/3rdparty/glslang/Test/baseResults/spv.int16.frag.out index 84128ab48..11818b75e 100644 --- a/3rdparty/glslang/Test/baseResults/spv.int16.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.int16.frag.out @@ -1,4 +1,8 @@ spv.int16.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability Float16 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 523 diff --git a/3rdparty/glslang/Test/baseResults/spv.int32.frag.out b/3rdparty/glslang/Test/baseResults/spv.int32.frag.out index d72de0deb..3b9342844 100644 --- a/3rdparty/glslang/Test/baseResults/spv.int32.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.int32.frag.out @@ -1,4 +1,8 @@ spv.int32.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.1 specification (or requires extension) + OpCapability Float16 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 493 diff --git a/3rdparty/glslang/Test/baseResults/spv.int8.frag.out b/3rdparty/glslang/Test/baseResults/spv.int8.frag.out index 14922b2b5..1c65384f3 100644 --- a/3rdparty/glslang/Test/baseResults/spv.int8.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.int8.frag.out @@ -1,4 +1,8 @@ spv.int8.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.1 specification (or requires extension) + OpCapability Float16 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 518 diff --git a/3rdparty/glslang/Test/baseResults/spv.length.frag.out b/3rdparty/glslang/Test/baseResults/spv.length.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.localAggregates.frag.out b/3rdparty/glslang/Test/baseResults/spv.localAggregates.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.loops.frag.out b/3rdparty/glslang/Test/baseResults/spv.loops.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.loopsArtificial.frag.out b/3rdparty/glslang/Test/baseResults/spv.loopsArtificial.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.matFun.vert.out b/3rdparty/glslang/Test/baseResults/spv.matFun.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.memoryQualifier.frag.out b/3rdparty/glslang/Test/baseResults/spv.memoryQualifier.frag.out index 02783b983..4113cc950 100644 --- a/3rdparty/glslang/Test/baseResults/spv.memoryQualifier.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.memoryQualifier.frag.out @@ -1,4 +1,8 @@ spv.memoryQualifier.frag +error: SPIRV-Tools Validation Errors +error: Capability ImageRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability ImageRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 97 diff --git a/3rdparty/glslang/Test/baseResults/spv.multiStruct.comp.out b/3rdparty/glslang/Test/baseResults/spv.multiStruct.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.multiStructFuncall.frag.out b/3rdparty/glslang/Test/baseResults/spv.multiStructFuncall.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.multiviewPerViewAttributes.tesc.out b/3rdparty/glslang/Test/baseResults/spv.multiviewPerViewAttributes.tesc.out index 5d4508be4..7874b9464 100644 --- a/3rdparty/glslang/Test/baseResults/spv.multiviewPerViewAttributes.tesc.out +++ b/3rdparty/glslang/Test/baseResults/spv.multiviewPerViewAttributes.tesc.out @@ -1,4 +1,8 @@ spv.multiviewPerViewAttributes.tesc +error: SPIRV-Tools Validation Errors +error: OpMemberName Member '5' index is larger than Type '27[gl_PositionPerViewNV]'s member count. + OpMemberName %gl_PerVertex_0 5 "gl_PositionPerViewNV" + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 37 diff --git a/3rdparty/glslang/Test/baseResults/spv.newTexture.frag.out b/3rdparty/glslang/Test/baseResults/spv.newTexture.frag.out old mode 100755 new mode 100644 index 5ddd8a545..5e462bed4 --- a/3rdparty/glslang/Test/baseResults/spv.newTexture.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.newTexture.frag.out @@ -1,4 +1,8 @@ spv.newTexture.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 284 diff --git a/3rdparty/glslang/Test/baseResults/spv.noWorkgroup.comp.out b/3rdparty/glslang/Test/baseResults/spv.noWorkgroup.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.nonSquare.vert.out b/3rdparty/glslang/Test/baseResults/spv.nonSquare.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.nonuniform.frag.out b/3rdparty/glslang/Test/baseResults/spv.nonuniform.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.offsets.frag.out b/3rdparty/glslang/Test/baseResults/spv.offsets.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.paramMemory.frag.out b/3rdparty/glslang/Test/baseResults/spv.paramMemory.frag.out old mode 100755 new mode 100644 index b593ce9c2..a7e627a2c --- a/3rdparty/glslang/Test/baseResults/spv.paramMemory.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.paramMemory.frag.out @@ -1,4 +1,8 @@ spv.paramMemory.frag +error: SPIRV-Tools Validation Errors +error: OpFunctionCall Argument '38[image1]'s type does not match Function '8's parameter type. + %41 = OpFunctionCall %v4float %image_load_I21_vi2_ %image1 %param + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 69 diff --git a/3rdparty/glslang/Test/baseResults/spv.precision.frag.out b/3rdparty/glslang/Test/baseResults/spv.precision.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.precisionNonESSamp.frag.out b/3rdparty/glslang/Test/baseResults/spv.precisionNonESSamp.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.prepost.frag.out b/3rdparty/glslang/Test/baseResults/spv.prepost.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.pushConstantAnon.vert.out b/3rdparty/glslang/Test/baseResults/spv.pushConstantAnon.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.qualifiers.vert.out b/3rdparty/glslang/Test/baseResults/spv.qualifiers.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.queryL.frag.out b/3rdparty/glslang/Test/baseResults/spv.queryL.frag.out old mode 100755 new mode 100644 index 33f0d95cd..b737a35b7 --- a/3rdparty/glslang/Test/baseResults/spv.queryL.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.queryL.frag.out @@ -1,4 +1,8 @@ spv.queryL.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 224 diff --git a/3rdparty/glslang/Test/baseResults/spv.rankShift.comp.out b/3rdparty/glslang/Test/baseResults/spv.rankShift.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.sample.frag.out b/3rdparty/glslang/Test/baseResults/spv.sample.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.sampleId.frag.out b/3rdparty/glslang/Test/baseResults/spv.sampleId.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.sampleMaskOverrideCoverage.frag.out b/3rdparty/glslang/Test/baseResults/spv.sampleMaskOverrideCoverage.frag.out index 470cd423f..ae7e8241b 100644 --- a/3rdparty/glslang/Test/baseResults/spv.sampleMaskOverrideCoverage.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.sampleMaskOverrideCoverage.frag.out @@ -1,4 +1,8 @@ spv.sampleMaskOverrideCoverage.frag +error: SPIRV-Tools Validation Errors +error: Operand 2 of Decorate requires one of these capabilities: SampleMaskOverrideCoverageNV + OpDecorate %gl_SampleMask OverrideCoverageNV + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 20 diff --git a/3rdparty/glslang/Test/baseResults/spv.samplePosition.frag.out b/3rdparty/glslang/Test/baseResults/spv.samplePosition.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.separate.frag.out b/3rdparty/glslang/Test/baseResults/spv.separate.frag.out index 346de5305..b9fefd709 100644 --- a/3rdparty/glslang/Test/baseResults/spv.separate.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.separate.frag.out @@ -1,4 +1,8 @@ spv.separate.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 319 diff --git a/3rdparty/glslang/Test/baseResults/spv.set.vert.out b/3rdparty/glslang/Test/baseResults/spv.set.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.shaderBallotAMD.comp.out b/3rdparty/glslang/Test/baseResults/spv.shaderBallotAMD.comp.out index df70095c1..a28791e66 100644 --- a/3rdparty/glslang/Test/baseResults/spv.shaderBallotAMD.comp.out +++ b/3rdparty/glslang/Test/baseResults/spv.shaderBallotAMD.comp.out @@ -1,4 +1,8 @@ spv.shaderBallotAMD.comp +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability Float16 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 1343 diff --git a/3rdparty/glslang/Test/baseResults/spv.simpleFunctionCall.frag.out b/3rdparty/glslang/Test/baseResults/spv.simpleFunctionCall.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.simpleMat.vert.out b/3rdparty/glslang/Test/baseResults/spv.simpleMat.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.sparseTexture.frag.out b/3rdparty/glslang/Test/baseResults/spv.sparseTexture.frag.out index d94f64386..78a2c2e75 100644 --- a/3rdparty/glslang/Test/baseResults/spv.sparseTexture.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.sparseTexture.frag.out @@ -1,4 +1,8 @@ spv.sparseTexture.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 438 diff --git a/3rdparty/glslang/Test/baseResults/spv.sparseTextureClamp.frag.out b/3rdparty/glslang/Test/baseResults/spv.sparseTextureClamp.frag.out index f63fd2f5c..fe210f749 100644 --- a/3rdparty/glslang/Test/baseResults/spv.sparseTextureClamp.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.sparseTextureClamp.frag.out @@ -1,4 +1,8 @@ spv.sparseTextureClamp.frag +error: SPIRV-Tools Validation Errors +error: Capability SampledRect is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability SampledRect + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 360 diff --git a/3rdparty/glslang/Test/baseResults/spv.specConst.vert.out b/3rdparty/glslang/Test/baseResults/spv.specConst.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.ssboAlias.frag.out b/3rdparty/glslang/Test/baseResults/spv.ssboAlias.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.stereoViewRendering.tesc.out b/3rdparty/glslang/Test/baseResults/spv.stereoViewRendering.tesc.out index c01467e7b..732e5b4ca 100644 --- a/3rdparty/glslang/Test/baseResults/spv.stereoViewRendering.tesc.out +++ b/3rdparty/glslang/Test/baseResults/spv.stereoViewRendering.tesc.out @@ -1,4 +1,8 @@ spv.stereoViewRendering.tesc +error: SPIRV-Tools Validation Errors +error: When BuiltIn decoration is applied to a structure-type member, all members of that structure type must also be decorated with BuiltIn (No allowed mixing of built-in variables and non-built-in variables within a single structure). Structure id 27 does not meet this requirement. + %gl_PerVertex_0 = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1 %v4float + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 38 diff --git a/3rdparty/glslang/Test/baseResults/spv.storageBuffer.vert.out b/3rdparty/glslang/Test/baseResults/spv.storageBuffer.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.structAssignment.frag.out b/3rdparty/glslang/Test/baseResults/spv.structAssignment.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.structDeref.frag.out b/3rdparty/glslang/Test/baseResults/spv.structDeref.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.structure.frag.out b/3rdparty/glslang/Test/baseResults/spv.structure.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.subgroupClusteredNeg.comp.out b/3rdparty/glslang/Test/baseResults/spv.subgroupClusteredNeg.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.subgroupPartitioned.comp.out b/3rdparty/glslang/Test/baseResults/spv.subgroupPartitioned.comp.out old mode 100755 new mode 100644 index 527a62e6c..f65d9962c --- a/3rdparty/glslang/Test/baseResults/spv.subgroupPartitioned.comp.out +++ b/3rdparty/glslang/Test/baseResults/spv.subgroupPartitioned.comp.out @@ -1,4 +1,8 @@ spv.subgroupPartitioned.comp +error: SPIRV-Tools Validation Errors +error: Opcode GroupNonUniformFAdd requires one of these capabilities: GroupNonUniformArithmetic GroupNonUniformClustered + %179 = OpGroupNonUniformFAdd %float %uint_3 PartitionedReduceNV %176 %177 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 2506 diff --git a/3rdparty/glslang/Test/baseResults/spv.switch.frag.out b/3rdparty/glslang/Test/baseResults/spv.switch.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.swizzle.frag.out b/3rdparty/glslang/Test/baseResults/spv.swizzle.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.swizzleInversion.frag.out b/3rdparty/glslang/Test/baseResults/spv.swizzleInversion.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.texture.frag.out b/3rdparty/glslang/Test/baseResults/spv.texture.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.texture.vert.out b/3rdparty/glslang/Test/baseResults/spv.texture.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.textureBuffer.vert.out b/3rdparty/glslang/Test/baseResults/spv.textureBuffer.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.textureGatherBiasLod.frag.out b/3rdparty/glslang/Test/baseResults/spv.textureGatherBiasLod.frag.out index faa4dba23..d01515dc5 100644 --- a/3rdparty/glslang/Test/baseResults/spv.textureGatherBiasLod.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.textureGatherBiasLod.frag.out @@ -1,4 +1,8 @@ spv.textureGatherBiasLod.frag +error: SPIRV-Tools Validation Errors +error: Image Operand Bias can only be used with ImplicitLod opcodes + %27 = OpImageGather %v4float %17 %21 %int_0 Bias %26 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 298 diff --git a/3rdparty/glslang/Test/baseResults/spv.types.frag.out b/3rdparty/glslang/Test/baseResults/spv.types.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.uint.frag.out b/3rdparty/glslang/Test/baseResults/spv.uint.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.unit1.frag.out b/3rdparty/glslang/Test/baseResults/spv.unit1.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.variableArrayIndex.frag.out b/3rdparty/glslang/Test/baseResults/spv.variableArrayIndex.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.varyingArray.frag.out b/3rdparty/glslang/Test/baseResults/spv.varyingArray.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.varyingArrayIndirect.frag.out b/3rdparty/glslang/Test/baseResults/spv.varyingArrayIndirect.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.vecMatConstruct.frag.out b/3rdparty/glslang/Test/baseResults/spv.vecMatConstruct.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.viewportArray2.tesc.out b/3rdparty/glslang/Test/baseResults/spv.viewportArray2.tesc.out index a9c9ba254..b14179ece 100644 --- a/3rdparty/glslang/Test/baseResults/spv.viewportArray2.tesc.out +++ b/3rdparty/glslang/Test/baseResults/spv.viewportArray2.tesc.out @@ -1,4 +1,8 @@ spv.viewportArray2.tesc +error: SPIRV-Tools Validation Errors +error: Vulkan spec allows BuiltIn ViewportIndex to be used only with Vertex, TessellationEvaluation, Geometry, or Fragment execution models. ID <0> (OpStore) is referencing ID <22> (OpVariable) which is decorated with BuiltIn ViewportIndex in function <4> called with execution model TessellationControl. + OpStore %gl_ViewportIndex %int_2 + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 25 diff --git a/3rdparty/glslang/Test/baseResults/spv.voidFunction.frag.out b/3rdparty/glslang/Test/baseResults/spv.voidFunction.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.vulkan100.subgroupArithmetic.comp.out b/3rdparty/glslang/Test/baseResults/spv.vulkan100.subgroupArithmetic.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.vulkan100.subgroupPartitioned.comp.out b/3rdparty/glslang/Test/baseResults/spv.vulkan100.subgroupPartitioned.comp.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.vulkan110.int16.frag.out b/3rdparty/glslang/Test/baseResults/spv.vulkan110.int16.frag.out old mode 100755 new mode 100644 index b6936f8c0..9141e4ec4 --- a/3rdparty/glslang/Test/baseResults/spv.vulkan110.int16.frag.out +++ b/3rdparty/glslang/Test/baseResults/spv.vulkan110.int16.frag.out @@ -1,4 +1,8 @@ spv.vulkan110.int16.frag +error: SPIRV-Tools Validation Errors +error: Capability Float16 is not allowed by Vulkan 1.1 specification (or requires extension) + OpCapability Float16 + // Module Version 10300 // Generated by (magic number): 80007 // Id's are bound by 523 diff --git a/3rdparty/glslang/Test/baseResults/spv.vulkan110.storageBuffer.vert.out b/3rdparty/glslang/Test/baseResults/spv.vulkan110.storageBuffer.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.while-simple.vert.out b/3rdparty/glslang/Test/baseResults/spv.while-simple.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.whileLoop.frag.out b/3rdparty/glslang/Test/baseResults/spv.whileLoop.frag.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/spv.xfb.vert.out b/3rdparty/glslang/Test/baseResults/spv.xfb.vert.out old mode 100755 new mode 100644 index 3cd93d500..68633e1f5 --- a/3rdparty/glslang/Test/baseResults/spv.xfb.vert.out +++ b/3rdparty/glslang/Test/baseResults/spv.xfb.vert.out @@ -1,4 +1,8 @@ spv.xfb.vert +error: SPIRV-Tools Validation Errors +error: Capability TransformFeedback is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability TransformFeedback + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 16 diff --git a/3rdparty/glslang/Test/baseResults/spv.xfb2.vert.out b/3rdparty/glslang/Test/baseResults/spv.xfb2.vert.out old mode 100755 new mode 100644 index a8551a1a7..6dc398721 --- a/3rdparty/glslang/Test/baseResults/spv.xfb2.vert.out +++ b/3rdparty/glslang/Test/baseResults/spv.xfb2.vert.out @@ -1,4 +1,8 @@ spv.xfb2.vert +error: SPIRV-Tools Validation Errors +error: Capability TransformFeedback is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability TransformFeedback + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 35 diff --git a/3rdparty/glslang/Test/baseResults/spv.xfb3.vert.out b/3rdparty/glslang/Test/baseResults/spv.xfb3.vert.out old mode 100755 new mode 100644 index 0218847e3..1d526aa9f --- a/3rdparty/glslang/Test/baseResults/spv.xfb3.vert.out +++ b/3rdparty/glslang/Test/baseResults/spv.xfb3.vert.out @@ -1,4 +1,8 @@ spv.xfb3.vert +error: SPIRV-Tools Validation Errors +error: Capability TransformFeedback is not allowed by Vulkan 1.0 specification (or requires extension) + OpCapability TransformFeedback + // Module Version 10000 // Generated by (magic number): 80007 // Id's are bound by 35 diff --git a/3rdparty/glslang/Test/baseResults/stringToDouble.vert.out b/3rdparty/glslang/Test/baseResults/stringToDouble.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/tokenPaste.vert.out b/3rdparty/glslang/Test/baseResults/tokenPaste.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/baseResults/vulkan.ast.vert.out b/3rdparty/glslang/Test/baseResults/vulkan.ast.vert.out old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/bump b/3rdparty/glslang/Test/bump index f23be33b6..03df6327b 100755 --- a/3rdparty/glslang/Test/bump +++ b/3rdparty/glslang/Test/bump @@ -1,2 +1,3 @@ +#!/usr/bin/env bash cp localResults/* baseResults/ diff --git a/3rdparty/glslang/Test/compoundsuffix.vert.glsl b/3rdparty/glslang/Test/compoundsuffix.vert.glsl old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/constFold.frag b/3rdparty/glslang/Test/constFold.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/cppBad2.vert b/3rdparty/glslang/Test/cppBad2.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/cppPassMacroName.frag b/3rdparty/glslang/Test/cppPassMacroName.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.amend.frag b/3rdparty/glslang/Test/hlsl.amend.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.boolConv.vert b/3rdparty/glslang/Test/hlsl.boolConv.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.charLit.vert b/3rdparty/glslang/Test/hlsl.charLit.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.flattenSubset.frag b/3rdparty/glslang/Test/hlsl.flattenSubset.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.flattenSubset2.frag b/3rdparty/glslang/Test/hlsl.flattenSubset2.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.hlslOffset.vert b/3rdparty/glslang/Test/hlsl.hlslOffset.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.implicitBool.frag b/3rdparty/glslang/Test/hlsl.implicitBool.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.inf.vert b/3rdparty/glslang/Test/hlsl.inf.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.intrinsics.frag b/3rdparty/glslang/Test/hlsl.intrinsics.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.logicalConvert.frag b/3rdparty/glslang/Test/hlsl.logicalConvert.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.multiEntry.vert b/3rdparty/glslang/Test/hlsl.multiEntry.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.multiReturn.frag b/3rdparty/glslang/Test/hlsl.multiReturn.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.namespace.frag b/3rdparty/glslang/Test/hlsl.namespace.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.nonstaticMemberFunction.frag b/3rdparty/glslang/Test/hlsl.nonstaticMemberFunction.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.partialInit.frag b/3rdparty/glslang/Test/hlsl.partialInit.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.pp.expand.frag b/3rdparty/glslang/Test/hlsl.pp.expand.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.scalarCast.vert b/3rdparty/glslang/Test/hlsl.scalarCast.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.shapeConvRet.frag b/3rdparty/glslang/Test/hlsl.shapeConvRet.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.staticFuncInit.frag b/3rdparty/glslang/Test/hlsl.staticFuncInit.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.staticMemberFunction.frag b/3rdparty/glslang/Test/hlsl.staticMemberFunction.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.string.frag b/3rdparty/glslang/Test/hlsl.string.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.structIoFourWay.frag b/3rdparty/glslang/Test/hlsl.structIoFourWay.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.structStructName.frag b/3rdparty/glslang/Test/hlsl.structStructName.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/hlsl.this.frag b/3rdparty/glslang/Test/hlsl.this.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/nosuffix b/3rdparty/glslang/Test/nosuffix old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/preprocessor.bad_arg.vert b/3rdparty/glslang/Test/preprocessor.bad_arg.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/runtests b/3rdparty/glslang/Test/runtests index ddd77ce8e..d36f3a8cf 100755 --- a/3rdparty/glslang/Test/runtests +++ b/3rdparty/glslang/Test/runtests @@ -130,13 +130,13 @@ diff -b $BASEDIR/spv.looseUniformNoLoc.vert.out $TARGETDIR/spv.looseUniformNoLoc # Testing debug information # echo Testing SPV Debug Information -$EXE -g --relaxed-errors --suppress-warnings --aml --amb --hlsl-offsets --nsf \ +$EXE -g --relaxed-errors --suppress-warnings --aml --amb --hlsl-offsets --nsf --spirv-val \ -G -H spv.debugInfo.frag --rsb frag 3 > $TARGETDIR/spv.debugInfo.frag.out diff -b $BASEDIR/spv.debugInfo.frag.out $TARGETDIR/spv.debugInfo.frag.out || HASERROR=1 -$EXE -g -Od --target-env vulkan1.1 --relaxed-errors --suppress-warnings --aml --hlsl-offsets --nsf \ +$EXE -g -Od --target-env vulkan1.1 --relaxed-errors --suppress-warnings --aml --hlsl-offsets --nsf --spirv-val \ -G -H spv.debugInfo.frag --rsb frag 3 > $TARGETDIR/spv.debugInfo.1.1.frag.out diff -b $BASEDIR/spv.debugInfo.1.1.frag.out $TARGETDIR/spv.debugInfo.1.1.frag.out || HASERROR=1 -$EXE -g -D -Od -e newMain -g --amb --aml --fua --hlsl-iomap --nsf --sib 1 --ssb 2 --sbb 3 --stb 4 --suavb 5 --sub 6 \ +$EXE -g -D -Od -e newMain -g --amb --aml --fua --hlsl-iomap --nsf --spirv-val --sib 1 --ssb 2 --sbb 3 --stb 4 --suavb 5 --sub 6 \ --sep origMain -H -Od spv.hlslDebugInfo.vert --rsb vert t0 0 0 > $TARGETDIR/spv.hlslDebugInfo.frag.out diff -b $BASEDIR/spv.hlslDebugInfo.frag.out $TARGETDIR/spv.hlslDebugInfo.frag.out || HASERROR=1 diff --git a/3rdparty/glslang/Test/runtimeArray.vert b/3rdparty/glslang/Test/runtimeArray.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.16bitstorage-int.frag b/3rdparty/glslang/Test/spv.16bitstorage-int.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.16bitstorage-uint.frag b/3rdparty/glslang/Test/spv.16bitstorage-uint.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.16bitstorage.frag b/3rdparty/glslang/Test/spv.16bitstorage.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.16bitstorage_Error-int.frag b/3rdparty/glslang/Test/spv.16bitstorage_Error-int.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.16bitstorage_Error-uint.frag b/3rdparty/glslang/Test/spv.16bitstorage_Error-uint.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.16bitstorage_Error.frag b/3rdparty/glslang/Test/spv.16bitstorage_Error.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.8bitstorage-int.frag b/3rdparty/glslang/Test/spv.8bitstorage-int.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.8bitstorage-uint.frag b/3rdparty/glslang/Test/spv.8bitstorage-uint.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.8bitstorage_Error-int.frag b/3rdparty/glslang/Test/spv.8bitstorage_Error-int.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.8bitstorage_Error-uint.frag b/3rdparty/glslang/Test/spv.8bitstorage_Error-uint.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.functionNestedOpaque.vert b/3rdparty/glslang/Test/spv.functionNestedOpaque.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.hlslOffsets.vert b/3rdparty/glslang/Test/spv.hlslOffsets.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.multiStructFuncall.frag b/3rdparty/glslang/Test/spv.multiStructFuncall.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.offsets.frag b/3rdparty/glslang/Test/spv.offsets.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.shortCircuit.frag b/3rdparty/glslang/Test/spv.shortCircuit.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.subgroupPartitioned.comp b/3rdparty/glslang/Test/spv.subgroupPartitioned.comp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.targetOpenGL.vert b/3rdparty/glslang/Test/spv.targetOpenGL.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.targetVulkan.vert b/3rdparty/glslang/Test/spv.targetVulkan.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.unit1.frag b/3rdparty/glslang/Test/spv.unit1.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.unit2.frag b/3rdparty/glslang/Test/spv.unit2.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.unit3.frag b/3rdparty/glslang/Test/spv.unit3.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/spv.vulkan100.subgroupPartitioned.comp b/3rdparty/glslang/Test/spv.vulkan100.subgroupPartitioned.comp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/stringToDouble.vert b/3rdparty/glslang/Test/stringToDouble.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/tokenLength.vert b/3rdparty/glslang/Test/tokenLength.vert old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/Test/vulkan.frag b/3rdparty/glslang/Test/vulkan.frag old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/CMakeLists.txt b/3rdparty/glslang/glslang/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/Include/Types.h b/3rdparty/glslang/glslang/Include/Types.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/Include/intermediate.h b/3rdparty/glslang/glslang/Include/intermediate.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/Include/revision.h b/3rdparty/glslang/glslang/Include/revision.h index 2ff52ded3..f2661d45d 100644 --- a/3rdparty/glslang/glslang/Include/revision.h +++ b/3rdparty/glslang/glslang/Include/revision.h @@ -1,3 +1,3 @@ // This header is generated by the make-revision script. -#define GLSLANG_PATCH_LEVEL 2853 +#define GLSLANG_PATCH_LEVEL 2870 diff --git a/3rdparty/glslang/glslang/MachineIndependent/Constant.cpp b/3rdparty/glslang/glslang/MachineIndependent/Constant.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/Initialize.cpp b/3rdparty/glslang/glslang/MachineIndependent/Initialize.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/Intermediate.cpp b/3rdparty/glslang/glslang/MachineIndependent/Intermediate.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/ParseHelper.cpp b/3rdparty/glslang/glslang/MachineIndependent/ParseHelper.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/ParseHelper.h b/3rdparty/glslang/glslang/MachineIndependent/ParseHelper.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/Scan.cpp b/3rdparty/glslang/glslang/MachineIndependent/Scan.cpp old mode 100755 new mode 100644 index 66ac4a85e..eaedbe810 --- a/3rdparty/glslang/glslang/MachineIndependent/Scan.cpp +++ b/3rdparty/glslang/glslang/MachineIndependent/Scan.cpp @@ -683,7 +683,7 @@ void TScanContext::fillInKeywordMap() (*KeywordMap)["smooth"] = SMOOTH; (*KeywordMap)["flat"] = FLAT; #ifdef AMD_EXTENSIONS - (*KeywordMap)["__explicitInterpAMD"] = __EXPLICITINTERPAMD; + (*KeywordMap)["__explicitInterpAMD"] = EXPLICITINTERPAMD; #endif (*KeywordMap)["centroid"] = CENTROID; (*KeywordMap)["precise"] = PRECISE; @@ -1490,7 +1490,7 @@ int TScanContext::tokenizeIdentifier() return keyword; #ifdef AMD_EXTENSIONS - case __EXPLICITINTERPAMD: + case EXPLICITINTERPAMD: if (parseContext.profile != EEsProfile && parseContext.version >= 450 && parseContext.extensionTurnedOn(E_GL_AMD_shader_explicit_vertex_parameter)) return keyword; diff --git a/3rdparty/glslang/glslang/MachineIndependent/ScanContext.h b/3rdparty/glslang/glslang/MachineIndependent/ScanContext.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/Versions.cpp b/3rdparty/glslang/glslang/MachineIndependent/Versions.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/Versions.h b/3rdparty/glslang/glslang/MachineIndependent/Versions.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/glslang.y b/3rdparty/glslang/glslang/MachineIndependent/glslang.y old mode 100755 new mode 100644 index 4d16b2b37..66a53d407 --- a/3rdparty/glslang/glslang/MachineIndependent/glslang.y +++ b/3rdparty/glslang/glslang/MachineIndependent/glslang.y @@ -146,7 +146,7 @@ extern int yylex(YYSTYPE*, TParseContext&); %token F16VEC2 F16VEC3 F16VEC4 F16MAT2 F16MAT3 F16MAT4 %token F32VEC2 F32VEC3 F32VEC4 F32MAT2 F32MAT3 F32MAT4 %token F64VEC2 F64VEC3 F64VEC4 F64MAT2 F64MAT3 F64MAT4 -%token NOPERSPECTIVE FLAT SMOOTH LAYOUT __EXPLICITINTERPAMD +%token NOPERSPECTIVE FLAT SMOOTH LAYOUT EXPLICITINTERPAMD %token MAT2X2 MAT2X3 MAT2X4 %token MAT3X2 MAT3X3 MAT3X4 @@ -1135,7 +1135,7 @@ interpolation_qualifier $$.init($1.loc); $$.qualifier.nopersp = true; } - | __EXPLICITINTERPAMD { + | EXPLICITINTERPAMD { #ifdef AMD_EXTENSIONS parseContext.globalCheck($1.loc, "__explicitInterpAMD"); parseContext.profileRequires($1.loc, ECoreProfile, 450, E_GL_AMD_shader_explicit_vertex_parameter, "explicit interpolation"); diff --git a/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp b/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp old mode 100755 new mode 100644 index 96e245edb..a79c480e6 --- a/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp +++ b/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp @@ -236,7 +236,7 @@ extern int yydebug; FLAT = 368, SMOOTH = 369, LAYOUT = 370, - __EXPLICITINTERPAMD = 371, + EXPLICITINTERPAMD = 371, MAT2X2 = 372, MAT2X3 = 373, MAT2X4 = 374, @@ -981,7 +981,7 @@ static const char *const yytname[] = "F16VEC2", "F16VEC3", "F16VEC4", "F16MAT2", "F16MAT3", "F16MAT4", "F32VEC2", "F32VEC3", "F32VEC4", "F32MAT2", "F32MAT3", "F32MAT4", "F64VEC2", "F64VEC3", "F64VEC4", "F64MAT2", "F64MAT3", "F64MAT4", - "NOPERSPECTIVE", "FLAT", "SMOOTH", "LAYOUT", "__EXPLICITINTERPAMD", + "NOPERSPECTIVE", "FLAT", "SMOOTH", "LAYOUT", "EXPLICITINTERPAMD", "MAT2X2", "MAT2X3", "MAT2X4", "MAT3X2", "MAT3X3", "MAT3X4", "MAT4X2", "MAT4X3", "MAT4X4", "DMAT2X2", "DMAT2X3", "DMAT2X4", "DMAT3X2", "DMAT3X3", "DMAT3X4", "DMAT4X2", "DMAT4X3", "DMAT4X4", "F16MAT2X2", diff --git a/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp.h b/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp.h index 7cfb79766..9085dd0c0 100644 --- a/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp.h +++ b/3rdparty/glslang/glslang/MachineIndependent/glslang_tab.cpp.h @@ -158,7 +158,7 @@ extern int yydebug; FLAT = 368, SMOOTH = 369, LAYOUT = 370, - __EXPLICITINTERPAMD = 371, + EXPLICITINTERPAMD = 371, MAT2X2 = 372, MAT2X3 = 373, MAT2X4 = 374, diff --git a/3rdparty/glslang/glslang/MachineIndependent/intermOut.cpp b/3rdparty/glslang/glslang/MachineIndependent/intermOut.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/iomapper.cpp b/3rdparty/glslang/glslang/MachineIndependent/iomapper.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/linkValidate.cpp b/3rdparty/glslang/glslang/MachineIndependent/linkValidate.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/localintermediate.h b/3rdparty/glslang/glslang/MachineIndependent/localintermediate.h old mode 100755 new mode 100644 index 9cbc4c8dc..e0dab9725 --- a/3rdparty/glslang/glslang/MachineIndependent/localintermediate.h +++ b/3rdparty/glslang/glslang/MachineIndependent/localintermediate.h @@ -41,6 +41,8 @@ #include "../Public/ShaderLang.h" #include "Versions.h" +#include +#include #include #include #include @@ -349,7 +351,7 @@ public: if (hlslOffsets) processes.addProcess("hlsl-offsets"); } - bool usingHlslOFfsets() const { return hlslOffsets; } + bool usingHlslOffsets() const { return hlslOffsets; } void setUseStorageBuffer() { useStorageBuffer = true; diff --git a/3rdparty/glslang/glslang/MachineIndependent/parseVersions.h b/3rdparty/glslang/glslang/MachineIndependent/parseVersions.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/preprocessor/Pp.cpp b/3rdparty/glslang/glslang/MachineIndependent/preprocessor/Pp.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpContext.cpp b/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpContext.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpContext.h b/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpContext.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpScanner.cpp b/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpScanner.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpTokens.cpp b/3rdparty/glslang/glslang/MachineIndependent/preprocessor/PpTokens.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/glslang/Public/ShaderLang.h b/3rdparty/glslang/glslang/Public/ShaderLang.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/gtests/AST.FromFile.cpp b/3rdparty/glslang/gtests/AST.FromFile.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/gtests/Hlsl.FromFile.cpp b/3rdparty/glslang/gtests/Hlsl.FromFile.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/gtests/Link.FromFile.Vk.cpp b/3rdparty/glslang/gtests/Link.FromFile.Vk.cpp old mode 100755 new mode 100644 index beb79e18e..22892f0b4 --- a/3rdparty/glslang/gtests/Link.FromFile.Vk.cpp +++ b/3rdparty/glslang/gtests/Link.FromFile.Vk.cpp @@ -79,6 +79,7 @@ TEST_P(LinkTestVulkan, FromFile) std::vector spirv_binary; glslang::SpvOptions options; options.disableOptimizer = true; + options.validate = true; glslang::GlslangToSpv(*program.getIntermediate(shaders.front()->getStage()), spirv_binary, &logger, &options); diff --git a/3rdparty/glslang/gtests/Pp.FromFile.cpp b/3rdparty/glslang/gtests/Pp.FromFile.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/gtests/Spv.FromFile.cpp b/3rdparty/glslang/gtests/Spv.FromFile.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/gtests/TestFixture.h b/3rdparty/glslang/gtests/TestFixture.h old mode 100644 new mode 100755 index a58978d33..3329fa3c2 --- a/3rdparty/glslang/gtests/TestFixture.h +++ b/3rdparty/glslang/gtests/TestFixture.h @@ -243,6 +243,7 @@ public: std::vector spirv_binary; glslang::SpvOptions options; options.disableOptimizer = !enableOptimizer; + options.validate = true; glslang::GlslangToSpv(*program.getIntermediate(stage), spirv_binary, &logger, &options); @@ -298,8 +299,10 @@ public: if (success && (controls & EShMsgSpvRules)) { std::vector spirv_binary; + glslang::SpvOptions options; + options.validate = true; glslang::GlslangToSpv(*program.getIntermediate(stage), - spirv_binary, &logger); + spirv_binary, &logger, &options); std::ostringstream disassembly_stream; spv::Parameterize(); @@ -338,8 +341,10 @@ public: if (success && (controls & EShMsgSpvRules)) { std::vector spirv_binary; + glslang::SpvOptions options; + options.validate = true; glslang::GlslangToSpv(*program.getIntermediate(stage), - spirv_binary, &logger); + spirv_binary, &logger, &options); spv::spirvbin_t(0 /*verbosity*/).remap(spirv_binary, remapOptions); diff --git a/3rdparty/glslang/hlsl/CMakeLists.txt b/3rdparty/glslang/hlsl/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslGrammar.cpp b/3rdparty/glslang/hlsl/hlslGrammar.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslGrammar.h b/3rdparty/glslang/hlsl/hlslGrammar.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslOpMap.cpp b/3rdparty/glslang/hlsl/hlslOpMap.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslOpMap.h b/3rdparty/glslang/hlsl/hlslOpMap.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslParseHelper.cpp b/3rdparty/glslang/hlsl/hlslParseHelper.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslParseHelper.h b/3rdparty/glslang/hlsl/hlslParseHelper.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslParseables.cpp b/3rdparty/glslang/hlsl/hlslParseables.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslParseables.h b/3rdparty/glslang/hlsl/hlslParseables.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslScanContext.cpp b/3rdparty/glslang/hlsl/hlslScanContext.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslScanContext.h b/3rdparty/glslang/hlsl/hlslScanContext.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslTokenStream.cpp b/3rdparty/glslang/hlsl/hlslTokenStream.cpp old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslTokenStream.h b/3rdparty/glslang/hlsl/hlslTokenStream.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/hlsl/hlslTokens.h b/3rdparty/glslang/hlsl/hlslTokens.h old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/known_good.json b/3rdparty/glslang/known_good.json old mode 100644 new mode 100755 index 0445eabc0..e7a83ce7d --- a/3rdparty/glslang/known_good.json +++ b/3rdparty/glslang/known_good.json @@ -5,7 +5,7 @@ "site" : "github", "subrepo" : "KhronosGroup/SPIRV-Tools", "subdir" : "External/spirv-tools", - "commit" : "714bf84e58abd9573488fc365707fb8f288ca73c" + "commit" : "6d27a8350fbc339909834a6ef339c805cb1ab69b" }, { "name" : "spirv-tools/external/spirv-headers", diff --git a/3rdparty/glslang/known_good_khr.json b/3rdparty/glslang/known_good_khr.json old mode 100755 new mode 100644 diff --git a/3rdparty/glslang/update_glslang_sources.py b/3rdparty/glslang/update_glslang_sources.py index a1cc0380a..65be2f6a2 100755 --- a/3rdparty/glslang/update_glslang_sources.py +++ b/3rdparty/glslang/update_glslang_sources.py @@ -96,7 +96,7 @@ class GoodCommit(object): def AddRemote(self): """Add the remote 'known-good' if it does not exist.""" remotes = command_output(['git', 'remote'], self.subdir).splitlines() - if 'known-good' not in remotes: + if b'known-good' not in remotes: command_output(['git', 'remote', 'add', 'known-good', self.GetUrl()], self.subdir) def HasCommit(self): diff --git a/3rdparty/spirv-tools/.appveyor.yml b/3rdparty/spirv-tools/.appveyor.yml deleted file mode 100644 index 5a4934876..000000000 --- a/3rdparty/spirv-tools/.appveyor.yml +++ /dev/null @@ -1,89 +0,0 @@ -# Windows Build Configuration for AppVeyor -# http://www.appveyor.com/docs/appveyor-yml - -# version format -version: "{build}" - -# The most recent compiler gives the most interesting new results. -# Put it first so we get its feedback first. -os: - - Visual Studio 2017 - - Visual Studio 2013 - -platform: - - x64 - -configuration: - - Debug - - Release - -branches: - only: - - master - -# Travis advances the master-tot tag to current top of the tree after -# each push into the master branch, because it relies on that tag to -# upload build artifacts to the master-tot release. This will cause -# double testing for each push on Appveyor: one for the push, one for -# the tag advance. Disable testing tags. -skip_tags: true - -clone_depth: 1 - -matrix: - fast_finish: true # Show final status immediately if a test fails. - exclude: - - os: Visual Studio 2013 - configuration: Debug - -# scripts that run after cloning repository -install: - # Install ninja - - set NINJA_URL="https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-win.zip" - - appveyor DownloadFile %NINJA_URL% -FileName ninja.zip - - 7z x ninja.zip -oC:\ninja > nul - - set PATH=C:\ninja;%PATH% - -before_build: - - git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers.git external/spirv-headers - - git clone --depth=1 https://github.com/google/googletest.git external/googletest - - git clone --depth=1 https://github.com/google/effcee.git external/effcee - - git clone --depth=1 https://github.com/google/re2.git external/re2 - # Set path and environment variables for the current Visual Studio version - - if "%APPVEYOR_BUILD_WORKER_IMAGE%"=="Visual Studio 2013" (call "C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\vcvarsall.bat" x86_amd64) - - if "%APPVEYOR_BUILD_WORKER_IMAGE%"=="Visual Studio 2017" (call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" x86_amd64) - -build: - parallel: true # enable MSBuild parallel builds - verbosity: minimal - -build_script: - - mkdir build && cd build - - cmake -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF .. - - ninja install - -test_script: - - ctest -C %CONFIGURATION% --output-on-failure --timeout 300 - -after_test: - # Zip build artifacts for uploading and deploying - - cd install - - 7z a SPIRV-Tools-master-windows-"%PLATFORM%"-"%CONFIGURATION%".zip *\* - -artifacts: - - path: build\install\*.zip - name: artifacts-zip - -deploy: - - provider: GitHub - auth_token: - secure: TMfcScKzzFIm1YgeV/PwCRXFDCw8Xm0wY2Vb2FU6WKlbzb5eUITTpr6I5vHPnAxS - release: master-tot - description: "Continuous build of the latest master branch by Appveyor and Travis CI" - artifact: artifacts-zip - draft: false - prerelease: false - force_update: true - on: - branch: master - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017 diff --git a/3rdparty/spirv-tools/.clang-format b/3rdparty/spirv-tools/.clang-format deleted file mode 100644 index 2b5d4a50b..000000000 --- a/3rdparty/spirv-tools/.clang-format +++ /dev/null @@ -1,5 +0,0 @@ ---- -Language: Cpp -BasedOnStyle: Google -DerivePointerAlignment: false -... diff --git a/3rdparty/spirv-tools/.gitignore b/3rdparty/spirv-tools/.gitignore deleted file mode 100644 index e5d8db4b1..000000000 --- a/3rdparty/spirv-tools/.gitignore +++ /dev/null @@ -1,15 +0,0 @@ -/build* -.ycm_extra_conf.py* -compile_commands.json -/external/googletest -/external/SPIRV-Headers -/external/spirv-headers -/external/effcee -/external/re2 -/TAGS -/.clang_complete -/utils/clang-format-diff.py - -# Vim -[._]*.s[a-w][a-z] -*~ diff --git a/3rdparty/spirv-tools/.travis.yml b/3rdparty/spirv-tools/.travis.yml deleted file mode 100644 index b4f87cd7b..000000000 --- a/3rdparty/spirv-tools/.travis.yml +++ /dev/null @@ -1,133 +0,0 @@ -# Linux Build Configuration for Travis - -language: cpp - -os: - - linux - - osx - -# Use Ubuntu 14.04 LTS (Trusty) as the Linux testing environment. -dist: trusty -sudo: false - -env: - global: - - secure: IoR/Xe9E+NnLAeI23WrmUsGQn5rocz+XRYUk+BbaoKiIRYm4q72GKyypRoOGLu7wImOXFSvnN/dpdnqIpx4W0NfsSvNdlXyhDy+wvT1kzTt77dJGnkGZTZ2SBOtC9AECLy4sqM9HG0rYRR6WfXcnP2GlrE5f2aF07aISQbOUsQMvyyhtCmVAzIigK1zIUto5I0pNenvo/Y+ur+mEvTh+FtaoDIGepCbZlCc+OxqRXwXNlI7mDXbzLPmTB1FWTGsrZdRX8czF9tN9Y+T79DQjB4Lcyyeow8yU9NBVlgzZJcp1xI0UIskRT8gVrXmBYL2dMeHnDQuhxjEg9n7jfr3ptA9rgwMaSsgdaLwuBXgtPuqVgUYDpE1cP8WI8q38MXX0I6psTs/WHu+z+5UwfjzpPOHmGdVt48o8ymFTapvD5Cf1+uJyk73QkyStnPIdBF1N9Yx5sD7HN28K6/Ro12sCCePHUZ9Uz1DdZI6XxkgCNKNwao0csAyvODxD6Ee43mkExtviB8BJY5jWLIMTdGhgEGH2sRqils8IDW0p8AOTPM4UC7iA7hdg3pA+XMvBHvP9ixsY7tuB+yR2AfnFaSw2DVbwI5GgFdFMNHXYuL+9V9Wuh3keBKYQT/Hy1YvxjQ/t9UouYHqEsyVFUl3R4lEAM9+qSRsRu+EKmcSO2QtCsWc= - matrix: - # Each line is a set of environment variables set before a build. - # Thus each line represents a different build configuration. - - BUILD_TYPE=RelWithDebInfo - - BUILD_TYPE=Debug - -compiler: - - clang - - gcc - -matrix: - fast_finish: true - include: - # Additional build using Android NDK with android-cmake - - env: BUILD_ANDROID_CMAKE=ON - # Additional build using Android NDK with Android.mk - - env: BUILD_ANDROID_MK=ON - # Additional check over format - - env: CHECK_FORMAT=ON - exclude: - # Skip GCC builds on macOS. - - os: osx - compiler: gcc - -cache: - apt: true - -git: - depth: 1 - -branches: - only: - - master - -before_install: - - if [[ "$BUILD_ANDROID_CMAKE" == "ON" ]] || [[ "$BUILD_ANDROID_MK" == "ON" ]]; then - git clone --depth=1 https://github.com/urho3d/android-ndk.git $HOME/android-ndk; - export ANDROID_NDK=$HOME/android-ndk; - git clone --depth=1 https://github.com/taka-no-me/android-cmake.git $HOME/android-cmake; - export TOOLCHAIN_PATH=$HOME/android-cmake/android.toolchain.cmake; - fi - - if [[ "$CHECK_FORMAT" == "ON" ]]; then - curl -L http://llvm.org/svn/llvm-project/cfe/trunk/tools/clang-format/clang-format-diff.py -o utils/clang-format-diff.py; - fi - -before_script: - - git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers - - git clone --depth=1 https://github.com/google/googletest external/googletest - - git clone --depth=1 https://github.com/google/effcee external/effcee - - git clone --depth=1 https://github.com/google/re2 external/re2 - -script: - # Due to the limitation of Travis platform, we cannot start too many concurrent jobs. - # Otherwise GCC will panic with internal error, possibility because of memory issues. - # ctest with the current tests doesn't profit from using more than 4 threads. - - export NPROC=4 - - mkdir build && cd build - - if [[ "$BUILD_ANDROID_MK" == "ON" ]]; then - export BUILD_DIR=$(pwd); - mkdir ${BUILD_DIR}/libs; - mkdir ${BUILD_DIR}/app; - $ANDROID_NDK/ndk-build -C ../android_test NDK_PROJECT_PATH=. - NDK_LIBS_OUT=${BUILD_DIR}/libs - NDK_APP_OUT=${BUILD_DIR}/app -j${NPROC}; - elif [[ "$BUILD_ANDROID_CMAKE" == "ON" ]]; then - cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_PATH} - -DANDROID_NATIVE_API_LEVEL=android-9 - -DCMAKE_BUILD_TYPE=Release - -DANDROID_ABI="armeabi-v7a with NEON" - -DSPIRV_BUILD_COMPRESSION=ON - -DSPIRV_SKIP_TESTS=ON ..; - make -j${NPROC}; - elif [[ "$CHECK_FORMAT" == "ON" ]]; then - cd ..; - ./utils/check_code_format.sh; - else - cmake -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_INSTALL_PREFIX=install ..; - make -j${NPROC} install; - ctest -j${NPROC} --output-on-failure --timeout 300; - fi - -after_success: - # Create tarball for deployment - - if [[ "${CC}" == "clang" && "${BUILD_ANDROID_MK}" != "ON" && "${BUILD_ANDROID_CMAKE}" != "ON" && "${CHECK_FORMAT}" != "ON" ]]; then - cd install; - export TARBALL=SPIRV-Tools-master-${TRAVIS_OS_NAME}-${BUILD_TYPE}.zip; - find . -print | zip -@ ${TARBALL}; - fi - -before_deploy: - # Tag the current master top of the tree as "master-tot". - # Travis CI relies on the tag name to push to the correct release. - - git config --global user.name "Travis CI" - - git config --global user.email "builds@travis-ci.org" - - git tag -f master-tot - - git push -q -f https://${spirvtoken}@github.com/KhronosGroup/SPIRV-Tools --tags - -deploy: - provider: releases - api_key: ${spirvtoken} - on: - branch: master - condition: ${CC} == clang && ${BUILD_ANDROID_MK} != ON && ${BUILD_ANDROID_CMAKE} != ON && ${CHECK_FORMAT} != ON - file: ${TARBALL} - skip_cleanup: true - overwrite: true - -notifications: - email: - recipients: - - andreyt@google.com - - antiagainst@google.com - - awoloszyn@google.com - - dneto@google.com - - ehsann@google.com - - qining@google.com - on_success: change - on_failure: always diff --git a/3rdparty/spirv-tools/Android.mk b/3rdparty/spirv-tools/Android.mk index b775541f3..cc336a892 100644 --- a/3rdparty/spirv-tools/Android.mk +++ b/3rdparty/spirv-tools/Android.mk @@ -24,7 +24,7 @@ SPVTOOLS_SRC_FILES := \ source/table.cpp \ source/text.cpp \ source/text_handler.cpp \ - source/util/bit_stream.cpp \ + source/util/bit_vector.cpp \ source/util/parse_number.cpp \ source/util/string_utils.cpp \ source/util/timer.cpp \ @@ -33,29 +33,38 @@ SPVTOOLS_SRC_FILES := \ source/val/function.cpp \ source/val/instruction.cpp \ source/val/validation_state.cpp \ - source/validate.cpp \ - source/validate_adjacency.cpp \ - source/validate_arithmetics.cpp \ - source/validate_atomics.cpp \ - source/validate_barriers.cpp \ - source/validate_bitwise.cpp \ - source/validate_builtins.cpp \ - source/validate_capability.cpp \ - source/validate_cfg.cpp \ - source/validate_composites.cpp \ - source/validate_conversion.cpp \ - source/validate_datarules.cpp \ - source/validate_decorations.cpp \ - source/validate_derivatives.cpp \ - source/validate_ext_inst.cpp \ - source/validate_id.cpp \ - source/validate_image.cpp \ - source/validate_instruction.cpp \ - source/validate_layout.cpp \ - source/validate_literals.cpp \ - source/validate_logicals.cpp \ - source/validate_primitives.cpp \ - source/validate_type_unique.cpp + source/val/validate.cpp \ + source/val/validate_adjacency.cpp \ + source/val/validate_annotation.cpp \ + source/val/validate_arithmetics.cpp \ + source/val/validate_atomics.cpp \ + source/val/validate_barriers.cpp \ + source/val/validate_bitwise.cpp \ + source/val/validate_builtins.cpp \ + source/val/validate_capability.cpp \ + source/val/validate_cfg.cpp \ + source/val/validate_composites.cpp \ + source/val/validate_constants.cpp \ + source/val/validate_conversion.cpp \ + source/val/validate_datarules.cpp \ + source/val/validate_debug.cpp \ + source/val/validate_decorations.cpp \ + source/val/validate_derivatives.cpp \ + source/val/validate_ext_inst.cpp \ + source/val/validate_execution_limitations.cpp \ + source/val/validate_function.cpp \ + source/val/validate_id.cpp \ + source/val/validate_image.cpp \ + source/val/validate_interfaces.cpp \ + source/val/validate_instruction.cpp \ + source/val/validate_memory.cpp \ + source/val/validate_mode_setting.cpp \ + source/val/validate_layout.cpp \ + source/val/validate_literals.cpp \ + source/val/validate_logicals.cpp \ + source/val/validate_non_uniform.cpp \ + source/val/validate_primitives.cpp \ + source/val/validate_type.cpp SPVTOOLS_OPT_SRC_FILES := \ source/opt/aggressive_dead_code_elim_pass.cpp \ @@ -65,6 +74,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/cfg.cpp \ source/opt/cfg_cleanup_pass.cpp \ source/opt/ccp_pass.cpp \ + source/opt/combine_access_chains.cpp \ source/opt/common_uniform_elim_pass.cpp \ source/opt/compact_ids_pass.cpp \ source/opt/composite.cpp \ @@ -91,7 +101,6 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/inline_pass.cpp \ source/opt/inline_exhaustive_pass.cpp \ source/opt/inline_opaque_pass.cpp \ - source/opt/insert_extract_elim.cpp \ source/opt/instruction.cpp \ source/opt/instruction_list.cpp \ source/opt/ir_context.cpp \ @@ -102,7 +111,12 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/local_single_block_elim_pass.cpp \ source/opt/local_single_store_elim_pass.cpp \ source/opt/local_ssa_elim_pass.cpp \ + source/opt/loop_dependence.cpp \ + source/opt/loop_dependence_helpers.cpp \ source/opt/loop_descriptor.cpp \ + source/opt/loop_fission.cpp \ + source/opt/loop_fusion.cpp \ + source/opt/loop_fusion_pass.cpp \ source/opt/loop_peeling.cpp \ source/opt/loop_unroller.cpp \ source/opt/loop_unswitch_pass.cpp \ @@ -115,7 +129,9 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/pass_manager.cpp \ source/opt/private_to_local_pass.cpp \ source/opt/propagator.cpp \ + source/opt/reduce_load_size.cpp \ source/opt/redundancy_elimination.cpp \ + source/opt/register_pressure.cpp \ source/opt/remove_duplicates_pass.cpp \ source/opt/replace_invalid_opc.cpp \ source/opt/scalar_analysis.cpp \ @@ -131,6 +147,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/types.cpp \ source/opt/unify_const_pass.cpp \ source/opt/value_number_table.cpp \ + source/opt/vector_dce.cpp \ source/opt/workaround1209.cpp # Locations of grammar files. @@ -293,7 +310,6 @@ include $(CLEAR_VARS) LOCAL_MODULE := SPIRV-Tools LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/include \ - $(LOCAL_PATH)/source \ $(LOCAL_PATH)/external/spirv-headers/include \ $(SPVTOOLS_OUT_PATH) LOCAL_EXPORT_C_INCLUDES := \ diff --git a/3rdparty/spirv-tools/BUILD.gn b/3rdparty/spirv-tools/BUILD.gn new file mode 100644 index 000000000..9fa9493df --- /dev/null +++ b/3rdparty/spirv-tools/BUILD.gn @@ -0,0 +1,756 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import("//build_overrides/spirv_tools.gni") + +import("//testing/test.gni") +import("//build_overrides/build.gni") + +spirv_headers = spirv_tools_spirv_headers_dir + +template("spvtools_core_tables") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + action("spvtools_core_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + core_insts_file = "${target_gen_dir}/core.insts-$version.inc" + operand_kinds_file = "${target_gen_dir}/operand.kinds-$version.inc" + extinst_file = "source/extinst.debuginfo.grammar.json" + + sources = [ + core_json_file, + ] + outputs = [ + core_insts_file, + operand_kinds_file, + ] + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--core-insts-output", + rebase_path(core_insts_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(extinst_file, root_build_dir), + "--operand-kinds-output", + rebase_path(operand_kinds_file, root_build_dir), + ] + } +} + +template("spvtools_core_enums") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + action("spvtools_core_enums_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + extension_enum_file = "${target_gen_dir}/extension_enum.inc" + extension_map_file = "${target_gen_dir}/enum_string_mapping.inc" + + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(debug_insts_file, root_build_dir), + "--extension-enum-output", + rebase_path(extension_enum_file, root_build_dir), + "--enum-string-mapping-output", + rebase_path(extension_map_file, root_build_dir), + ] + inputs = [ + core_json_file, + ] + outputs = [ + extension_enum_file, + extension_map_file, + ] + } +} + +template("spvtools_glsl_tables") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + action("spvtools_glsl_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + glsl_json_file = "${spirv_headers}/include/spirv/${version}/extinst.glsl.std.450.grammar.json" + glsl_insts_file = "${target_gen_dir}/glsl.std.450.insts.inc" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--extinst-glsl-grammar", + rebase_path(glsl_json_file, root_build_dir), + "--glsl-insts-output", + rebase_path(glsl_insts_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(debug_insts_file, root_build_dir), + ] + inputs = [ + core_json_file, + glsl_json_file, + ] + outputs = [ + glsl_insts_file, + ] + } +} + +template("spvtools_opencl_tables") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + action("spvtools_opencl_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + opengl_json_file = "${spirv_headers}/include/spirv/${version}/extinst.opencl.std.100.grammar.json" + opencl_insts_file = "${target_gen_dir}/opencl.std.insts.inc" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--extinst-opencl-grammar", + rebase_path(opengl_json_file, root_build_dir), + "--opencl-insts-output", + rebase_path(opencl_insts_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(debug_insts_file, root_build_dir), + ] + inputs = [ + core_json_file, + opengl_json_file, + ] + outputs = [ + opencl_insts_file, + ] + } +} + +template("spvtools_language_header") { + assert(defined(invoker.name), "Need name in $target_name generation.") + + action("spvtools_language_header_" + target_name) { + script = "utils/generate_language_headers.py" + + name = invoker.name + extinst_output_base = "${target_gen_dir}/${name}" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + + args = [ + "--extinst-name", + "${name}", + "--extinst-grammar", + rebase_path(debug_insts_file, root_build_dir), + "--extinst-output-base", + rebase_path(extinst_output_base, root_build_dir), + ] + inputs = [ + debug_insts_file, + ] + outputs = [ + "${extinst_output_base}.h", + ] + } +} + +template("spvtools_vendor_table") { + assert(defined(invoker.name), "Need name in $target_name generation.") + + action("spvtools_vendor_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + name = invoker.name + extinst_vendor_grammar = "source/extinst.${name}.grammar.json" + extinst_file = "${target_gen_dir}/${name}.insts.inc" + + args = [ + "--extinst-vendor-grammar", + rebase_path(extinst_vendor_grammar, root_build_dir), + "--vendor-insts-output", + rebase_path(extinst_file, root_build_dir), + ] + inputs = [ + extinst_vendor_grammar, + ] + outputs = [ + extinst_file, + ] + } +} + +action("spvtools_generators_inc") { + script = "utils/generate_registry_tables.py" + + # TODO(dsinclair): Make work for chrome + xml_file = "${spirv_headers}/include/spirv/spir-v.xml" + inc_file = "${target_gen_dir}/generators.inc" + + sources = [ + xml_file, + ] + outputs = [ + inc_file, + ] + args = [ + "--xml", + rebase_path(xml_file, root_build_dir), + "--generator", + rebase_path(inc_file, root_build_dir), + ] +} + +action("spvtools_build_version") { + script = "utils/update_build_version.py" + + src_dir = "." + inc_file = "${target_gen_dir}/build-version.inc" + + outputs = [ + inc_file, + ] + args = [ + rebase_path(src_dir, root_build_dir), + rebase_path(inc_file, root_build_dir), + ] +} + +spvtools_core_tables("unified1") { + version = "unified1" +} +spvtools_core_enums("unified1") { + version = "unified1" +} +spvtools_glsl_tables("glsl1-0") { + version = "1.0" +} +spvtools_opencl_tables("opencl1-0") { + version = "1.0" +} +spvtools_language_header("unified1") { + name = "DebugInfo" +} + +spvtools_vendor_tables = [ + "spv-amd-shader-explicit-vertex-parameter", + "spv-amd-shader-trinary-minmax", + "spv-amd-gcn-shader", + "spv-amd-shader-ballot", + "debuginfo", +] + +foreach(table, spvtools_vendor_tables) { + spvtools_vendor_table(table) { + name = table + } +} + +config("spvtools_config") { + include_dirs = [ + ".", + "include", + "$target_gen_dir", + "${spirv_headers}/include", + ] + + if (is_clang) { + cflags = [ "-Wno-implicit-fallthrough" ] + } +} + +static_library("spvtools") { + deps = [ + ":spvtools_core_enums_unified1", + ":spvtools_core_tables_unified1", + ":spvtools_generators_inc", + ":spvtools_glsl_tables_glsl1-0", + ":spvtools_language_header_unified1", + ":spvtools_opencl_tables_opencl1-0", + ] + foreach(target_name, spvtools_vendor_tables) { + deps += [ ":spvtools_vendor_tables_$target_name" ] + } + + sources = [ + "source/assembly_grammar.cpp", + "source/assembly_grammar.h", + "source/binary.cpp", + "source/binary.h", + "source/diagnostic.cpp", + "source/diagnostic.h", + "source/disassemble.cpp", + "source/enum_set.h", + "source/enum_string_mapping.cpp", + "source/ext_inst.cpp", + "source/ext_inst.h", + "source/extensions.cpp", + "source/extensions.h", + "source/instruction.h", + "source/libspirv.cpp", + "source/macro.h", + "source/name_mapper.cpp", + "source/name_mapper.h", + "source/opcode.cpp", + "source/opcode.h", + "source/operand.cpp", + "source/operand.h", + "source/parsed_operand.cpp", + "source/parsed_operand.h", + "source/print.cpp", + "source/print.h", + "source/spirv_constant.h", + "source/spirv_definition.h", + "source/spirv_endian.cpp", + "source/spirv_endian.h", + "source/spirv_target_env.cpp", + "source/spirv_target_env.h", + "source/spirv_validator_options.cpp", + "source/spirv_validator_options.h", + "source/table.cpp", + "source/table.h", + "source/text.cpp", + "source/text.h", + "source/text_handler.cpp", + "source/text_handler.h", + "source/util/bit_vector.cpp", + "source/util/bit_vector.h", + "source/util/bitutils.h", + "source/util/hex_float.h", + "source/util/ilist.h", + "source/util/ilist_node.h", + "source/util/make_unique.h", + "source/util/parse_number.cpp", + "source/util/parse_number.h", + "source/util/small_vector.h", + "source/util/string_utils.cpp", + "source/util/string_utils.h", + "source/util/timer.cpp", + "source/util/timer.h", + ] + + public_configs = [ ":spvtools_config" ] + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ "//build/config/compiler:no_chromium_code" ] +} + +static_library("spvtools_val") { + sources = [ + "source/val/basic_block.cpp", + "source/val/construct.cpp", + "source/val/function.cpp", + "source/val/instruction.cpp", + "source/val/validate.cpp", + "source/val/validate.h", + "source/val/validate_adjacency.cpp", + "source/val/validate_annotation.cpp", + "source/val/validate_arithmetics.cpp", + "source/val/validate_atomics.cpp", + "source/val/validate_barriers.cpp", + "source/val/validate_bitwise.cpp", + "source/val/validate_builtins.cpp", + "source/val/validate_capability.cpp", + "source/val/validate_cfg.cpp", + "source/val/validate_composites.cpp", + "source/val/validate_constants.cpp", + "source/val/validate_conversion.cpp", + "source/val/validate_datarules.cpp", + "source/val/validate_debug.cpp", + "source/val/validate_decorations.cpp", + "source/val/validate_derivatives.cpp", + "source/val/validate_execution_limitations.cpp", + "source/val/validate_ext_inst.cpp", + "source/val/validate_function.cpp", + "source/val/validate_id.cpp", + "source/val/validate_image.cpp", + "source/val/validate_instruction.cpp", + "source/val/validate_interfaces.cpp", + "source/val/validate_layout.cpp", + "source/val/validate_literals.cpp", + "source/val/validate_logicals.cpp", + "source/val/validate_memory.cpp", + "source/val/validate_mode_setting.cpp", + "source/val/validate_non_uniform.cpp", + "source/val/validate_primitives.cpp", + "source/val/validate_type.cpp", + "source/val/validation_state.cpp", + ] + + deps = [ + ":spvtools", + ] + + public_configs = [ ":spvtools_config" ] + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ "//build/config/compiler:no_chromium_code" ] +} + +static_library("spvtools_opt") { + sources = [ + "source/opt/aggressive_dead_code_elim_pass.cpp", + "source/opt/aggressive_dead_code_elim_pass.h", + "source/opt/basic_block.cpp", + "source/opt/basic_block.h", + "source/opt/block_merge_pass.cpp", + "source/opt/block_merge_pass.h", + "source/opt/build_module.cpp", + "source/opt/build_module.h", + "source/opt/ccp_pass.cpp", + "source/opt/ccp_pass.h", + "source/opt/cfg.cpp", + "source/opt/cfg.h", + "source/opt/cfg_cleanup_pass.cpp", + "source/opt/cfg_cleanup_pass.h", + "source/opt/combine_access_chains.cpp", + "source/opt/combine_access_chains.h", + "source/opt/common_uniform_elim_pass.cpp", + "source/opt/common_uniform_elim_pass.h", + "source/opt/compact_ids_pass.cpp", + "source/opt/compact_ids_pass.h", + "source/opt/composite.cpp", + "source/opt/composite.h", + "source/opt/const_folding_rules.cpp", + "source/opt/const_folding_rules.h", + "source/opt/constants.cpp", + "source/opt/constants.h", + "source/opt/copy_prop_arrays.cpp", + "source/opt/copy_prop_arrays.h", + "source/opt/dead_branch_elim_pass.cpp", + "source/opt/dead_branch_elim_pass.h", + "source/opt/dead_insert_elim_pass.cpp", + "source/opt/dead_insert_elim_pass.h", + "source/opt/dead_variable_elimination.cpp", + "source/opt/dead_variable_elimination.h", + "source/opt/decoration_manager.cpp", + "source/opt/decoration_manager.h", + "source/opt/def_use_manager.cpp", + "source/opt/def_use_manager.h", + "source/opt/dominator_analysis.cpp", + "source/opt/dominator_analysis.h", + "source/opt/dominator_tree.cpp", + "source/opt/dominator_tree.h", + "source/opt/eliminate_dead_constant_pass.cpp", + "source/opt/eliminate_dead_constant_pass.h", + "source/opt/eliminate_dead_functions_pass.cpp", + "source/opt/eliminate_dead_functions_pass.h", + "source/opt/feature_manager.cpp", + "source/opt/feature_manager.h", + "source/opt/flatten_decoration_pass.cpp", + "source/opt/flatten_decoration_pass.h", + "source/opt/fold.cpp", + "source/opt/fold.h", + "source/opt/fold_spec_constant_op_and_composite_pass.cpp", + "source/opt/fold_spec_constant_op_and_composite_pass.h", + "source/opt/folding_rules.cpp", + "source/opt/folding_rules.h", + "source/opt/freeze_spec_constant_value_pass.cpp", + "source/opt/freeze_spec_constant_value_pass.h", + "source/opt/function.cpp", + "source/opt/function.h", + "source/opt/if_conversion.cpp", + "source/opt/if_conversion.h", + "source/opt/inline_exhaustive_pass.cpp", + "source/opt/inline_exhaustive_pass.h", + "source/opt/inline_opaque_pass.cpp", + "source/opt/inline_opaque_pass.h", + "source/opt/inline_pass.cpp", + "source/opt/inline_pass.h", + "source/opt/instruction.cpp", + "source/opt/instruction.h", + "source/opt/instruction_list.cpp", + "source/opt/instruction_list.h", + "source/opt/ir_builder.h", + "source/opt/ir_context.cpp", + "source/opt/ir_context.h", + "source/opt/ir_loader.cpp", + "source/opt/ir_loader.h", + "source/opt/iterator.h", + "source/opt/licm_pass.cpp", + "source/opt/licm_pass.h", + "source/opt/local_access_chain_convert_pass.cpp", + "source/opt/local_access_chain_convert_pass.h", + "source/opt/local_redundancy_elimination.cpp", + "source/opt/local_redundancy_elimination.h", + "source/opt/local_single_block_elim_pass.cpp", + "source/opt/local_single_block_elim_pass.h", + "source/opt/local_single_store_elim_pass.cpp", + "source/opt/local_single_store_elim_pass.h", + "source/opt/local_ssa_elim_pass.cpp", + "source/opt/local_ssa_elim_pass.h", + "source/opt/log.h", + "source/opt/loop_dependence.cpp", + "source/opt/loop_dependence.h", + "source/opt/loop_dependence_helpers.cpp", + "source/opt/loop_descriptor.cpp", + "source/opt/loop_descriptor.h", + "source/opt/loop_fission.cpp", + "source/opt/loop_fission.h", + "source/opt/loop_fusion.cpp", + "source/opt/loop_fusion.h", + "source/opt/loop_fusion_pass.cpp", + "source/opt/loop_fusion_pass.h", + "source/opt/loop_peeling.cpp", + "source/opt/loop_peeling.h", + "source/opt/loop_unroller.cpp", + "source/opt/loop_unroller.h", + "source/opt/loop_unswitch_pass.cpp", + "source/opt/loop_unswitch_pass.h", + "source/opt/loop_utils.cpp", + "source/opt/loop_utils.h", + "source/opt/mem_pass.cpp", + "source/opt/mem_pass.h", + "source/opt/merge_return_pass.cpp", + "source/opt/merge_return_pass.h", + "source/opt/module.cpp", + "source/opt/module.h", + "source/opt/null_pass.h", + "source/opt/optimizer.cpp", + "source/opt/pass.cpp", + "source/opt/pass.h", + "source/opt/pass_manager.cpp", + "source/opt/pass_manager.h", + "source/opt/passes.h", + "source/opt/private_to_local_pass.cpp", + "source/opt/private_to_local_pass.h", + "source/opt/propagator.cpp", + "source/opt/propagator.h", + "source/opt/reduce_load_size.cpp", + "source/opt/reduce_load_size.h", + "source/opt/redundancy_elimination.cpp", + "source/opt/redundancy_elimination.h", + "source/opt/reflect.h", + "source/opt/register_pressure.cpp", + "source/opt/register_pressure.h", + "source/opt/remove_duplicates_pass.cpp", + "source/opt/remove_duplicates_pass.h", + "source/opt/replace_invalid_opc.cpp", + "source/opt/replace_invalid_opc.h", + "source/opt/scalar_analysis.cpp", + "source/opt/scalar_analysis.h", + "source/opt/scalar_analysis_nodes.h", + "source/opt/scalar_analysis_simplification.cpp", + "source/opt/scalar_replacement_pass.cpp", + "source/opt/scalar_replacement_pass.h", + "source/opt/set_spec_constant_default_value_pass.cpp", + "source/opt/set_spec_constant_default_value_pass.h", + "source/opt/simplification_pass.cpp", + "source/opt/simplification_pass.h", + "source/opt/ssa_rewrite_pass.cpp", + "source/opt/ssa_rewrite_pass.h", + "source/opt/strength_reduction_pass.cpp", + "source/opt/strength_reduction_pass.h", + "source/opt/strip_debug_info_pass.cpp", + "source/opt/strip_debug_info_pass.h", + "source/opt/strip_reflect_info_pass.cpp", + "source/opt/strip_reflect_info_pass.h", + "source/opt/tree_iterator.h", + "source/opt/type_manager.cpp", + "source/opt/type_manager.h", + "source/opt/types.cpp", + "source/opt/types.h", + "source/opt/unify_const_pass.cpp", + "source/opt/unify_const_pass.h", + "source/opt/value_number_table.cpp", + "source/opt/value_number_table.h", + "source/opt/vector_dce.cpp", + "source/opt/vector_dce.h", + "source/opt/workaround1209.cpp", + "source/opt/workaround1209.h", + ] + deps = [ + ":spvtools", + ] + + public_configs = [ ":spvtools_config" ] + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ "//build/config/compiler:no_chromium_code" ] +} + +group("SPIRV-Tools") { + deps = [ + ":spvtools", + ":spvtools_opt", + ":spvtools_val", + ] +} + +if (!build_with_chromium) { + googletest_dir = spirv_tools_googletest_dir + + config("gtest_config") { + include_dirs = [ + "${googletest_dir}/googletest", + "${googletest_dir}/googletest/include", + ] + } + + static_library("gtest") { + testonly = true + sources = [ + "${googletest_dir}/googletest/src/gtest-all.cc", + ] + public_configs = [ ":gtest_config" ] + } + + config("gmock_config") { + include_dirs = [ + "${googletest_dir}/googlemock", + "${googletest_dir}/googlemock/include", + "${googletest_dir}/googletest/include", + ] + if (is_clang) { + # TODO: Can remove this if/when the issue is fixed. + # https://github.com/google/googletest/issues/533 + cflags = [ "-Wno-inconsistent-missing-override" ] + } + } + + static_library("gmock") { + testonly = true + sources = [ + "${googletest_dir}/googlemock/src/gmock-all.cc", + ] + public_configs = [ ":gmock_config" ] + } +} + +config("spvtools_test_config") { + if (is_clang) { + cflags = [ "-Wno-self-assign" ] + } +} + +test("spvtools_test") { + sources = [ + "test/assembly_context_test.cpp", + "test/assembly_format_test.cpp", + "test/binary_destroy_test.cpp", + "test/binary_endianness_test.cpp", + "test/binary_header_get_test.cpp", + "test/binary_parse_test.cpp", + "test/binary_strnlen_s_test.cpp", + "test/binary_to_text.literal_test.cpp", + "test/binary_to_text_test.cpp", + "test/comment_test.cpp", + "test/enum_set_test.cpp", + "test/enum_string_mapping_test.cpp", + "test/ext_inst.debuginfo_test.cpp", + "test/ext_inst.glsl_test.cpp", + "test/ext_inst.opencl_test.cpp", + "test/fix_word_test.cpp", + "test/generator_magic_number_test.cpp", + "test/hex_float_test.cpp", + "test/immediate_int_test.cpp", + "test/libspirv_macros_test.cpp", + "test/name_mapper_test.cpp", + "test/named_id_test.cpp", + "test/opcode_make_test.cpp", + "test/opcode_require_capabilities_test.cpp", + "test/opcode_split_test.cpp", + "test/opcode_table_get_test.cpp", + "test/operand_capabilities_test.cpp", + "test/operand_pattern_test.cpp", + "test/operand_test.cpp", + "test/target_env_test.cpp", + "test/test_fixture.h", + "test/text_advance_test.cpp", + "test/text_destroy_test.cpp", + "test/text_literal_test.cpp", + "test/text_start_new_inst_test.cpp", + "test/text_to_binary.annotation_test.cpp", + "test/text_to_binary.barrier_test.cpp", + "test/text_to_binary.constant_test.cpp", + "test/text_to_binary.control_flow_test.cpp", + "test/text_to_binary.debug_test.cpp", + "test/text_to_binary.device_side_enqueue_test.cpp", + "test/text_to_binary.extension_test.cpp", + "test/text_to_binary.function_test.cpp", + "test/text_to_binary.group_test.cpp", + "test/text_to_binary.image_test.cpp", + "test/text_to_binary.literal_test.cpp", + "test/text_to_binary.memory_test.cpp", + "test/text_to_binary.misc_test.cpp", + "test/text_to_binary.mode_setting_test.cpp", + "test/text_to_binary.pipe_storage_test.cpp", + "test/text_to_binary.reserved_sampling_test.cpp", + "test/text_to_binary.subgroup_dispatch_test.cpp", + "test/text_to_binary.type_declaration_test.cpp", + "test/text_to_binary_test.cpp", + "test/text_word_get_test.cpp", + "test/unit_spirv.cpp", + "test/unit_spirv.h", + ] + + deps = [ + ":spvtools", + ":spvtools_language_header_unified1", + ":spvtools_val", + ] + + if (build_with_chromium) { + deps += [ + "//testing/gmock", + "//testing/gtest", + "//testing/gtest:gtest_main", + ] + } else { + deps += [ + ":gmock", + ":gtest", + ] + sources += [ "${googletest_dir}/googletest/src/gtest_main.cc" ] + } + + configs += [ + ":spvtools_config", + ":spvtools_test_config", + ] +} + +if (spirv_tools_standalone) { + group("fuzzers") { + testonly = true + deps = [ + "test/fuzzers", + ] + } +} + +executable("spirv-as") { + sources = [ + "source/software_version.cpp", + "tools/as/as.cpp", + ] + deps = [ + ":spvtools", + ":spvtools_build_version", + ] + configs += [ ":spvtools_config" ] +} diff --git a/3rdparty/spirv-tools/CHANGES b/3rdparty/spirv-tools/CHANGES index 9285dbfb9..ef499027f 100644 --- a/3rdparty/spirv-tools/CHANGES +++ b/3rdparty/spirv-tools/CHANGES @@ -1,6 +1,92 @@ Revision history for SPIRV-Tools -v2018.3-dev 2018-04-06 +v2018.5-dev 2018-07-08 + - General: + - Support Chromium GN build + - Use Kokoro bots: + - Disable Travis-CI bots + - Disable AppVeyor VisualStudio Release builds. Keep VS 2017 Debug build + - Don't check export symbols on OSX (Darwin): some installations don't have 'objdump' + - Reorganize source files and namespaces + - Fixes for ClangTidy, and whitespace (passes 'git cl presumit --all -uf') + - Fix unused param compile warnings/errors when Effcee not present + - Avoid including time headers when timer functionality is disabled + - #1688: Use binary mode on stdin; fixes "spirv-dis ") + if (NOT "${SPIRV_SKIP_TESTS}") + add_test(NAME spirv-tools-symbol-exports-${TARGET} + COMMAND ${PYTHON_EXECUTABLE} + ${spirv-tools_SOURCE_DIR}/utils/check_symbol_exports.py "$") + endif() endmacro() else() macro(spvtools_check_symbol_exports TARGET) - message("Skipping symbol exports test for ${TARGET}") + if (NOT "${SPIRV_SKIP_TESTS}") + message("Skipping symbol exports test for ${TARGET}") + endif() endmacro() endif() diff --git a/3rdparty/spirv-tools/CONTRIBUTING.md b/3rdparty/spirv-tools/CONTRIBUTING.md index f945f88fa..93a5610ee 100644 --- a/3rdparty/spirv-tools/CONTRIBUTING.md +++ b/3rdparty/spirv-tools/CONTRIBUTING.md @@ -87,6 +87,9 @@ usual things: * Identify potential functional problems. * Identify code duplication. * Ensure the unit tests have enough coverage. +* Ensure continuous integration (CI) bots run on the PR. If not run (in the + case of PRs by external contributors), add the "kokoro:run" label to the + pull request which will trigger running all CI jobs. When looking for functional problems, there are some common problems reviewers should pay particular attention to: diff --git a/3rdparty/spirv-tools/DEPS b/3rdparty/spirv-tools/DEPS new file mode 100644 index 000000000..b69f0320c --- /dev/null +++ b/3rdparty/spirv-tools/DEPS @@ -0,0 +1,174 @@ +use_relative_paths = True + +vars = { + 'chromium_git': 'https://chromium.googlesource.com', + 'github': 'https://github.com', + + 'build_revision': '037f38ae0fe5e11b4f7c33b750fd7a1e9634a606', + 'buildtools_revision': 'ab7b6a7b350dd15804c87c20ce78982811fdd76f', + 'clang_revision': 'abe5e4f9dc0f1df848c7a0efa05256253e77a7b7', + 'effcee_revision': '04b624799f5a9dbaf3fa1dbed2ba9dce2fc8dcf2', + 'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036', + 'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59', + 're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f', + 'spirv_headers_revision': 'ff684ffc6a35d2a58f0f63108877d0064ea33feb', +} + +deps = { + "build": + Var('chromium_git') + "/chromium/src/build.git@" + Var('build_revision'), + + 'buildtools': + Var('chromium_git') + '/chromium/buildtools.git@' + + Var('buildtools_revision'), + + 'external/spirv-headers': + Var('github') + '/KhronosGroup/SPIRV-Headers.git@' + + Var('spirv_headers_revision'), + + 'external/googletest': + Var('github') + '/google/googletest.git@' + Var('googletest_revision'), + + 'external/effcee': + Var('github') + '/google/effcee.git@' + Var('effcee_revision'), + + 'external/re2': + Var('github') + '/google/re2.git@' + Var('re2_revision'), + + 'testing': + Var('chromium_git') + '/chromium/src/testing@' + + Var('testing_revision'), + + 'tools/clang': + Var('chromium_git') + '/chromium/src/tools/clang@' + Var('clang_revision') +} + +recursedeps = [ + # buildtools provides clang_format, libc++, and libc++api + 'buildtools', +] + +hooks = [ + { + 'name': 'gn_win', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=win32', + '--no_auth', + '--bucket', 'chromium-gn', + '-s', 'SPIRV-Tools/buildtools/win/gn.exe.sha1', + ], + }, + { + 'name': 'gn_mac', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=darwin', + '--no_auth', + '--bucket', 'chromium-gn', + '-s', 'SPIRV-Tools/buildtools/mac/gn.sha1', + ], + }, + { + 'name': 'gn_linux64', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=linux*', + '--no_auth', + '--bucket', 'chromium-gn', + '-s', 'SPIRV-Tools/buildtools/linux64/gn.sha1', + ], + }, + # Pull clang-format binaries using checked-in hashes. + { + 'name': 'clang_format_win', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=win32', + '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'SPIRV-Tools/buildtools/win/clang-format.exe.sha1', + ], + }, + { + 'name': 'clang_format_mac', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=darwin', + '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'SPIRV-Tools/buildtools/mac/clang-format.sha1', + ], + }, + { + 'name': 'clang_format_linux', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=linux*', + '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'SPIRV-Tools/buildtools/linux64/clang-format.sha1', + ], + }, + { + # Pull clang + 'name': 'clang', + 'pattern': '.', + 'action': ['python', + 'SPIRV-Tools/tools/clang/scripts/update.py' + ], + }, + { + 'name': 'sysroot_arm', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_arm', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=arm'], + }, + { + 'name': 'sysroot_arm64', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_arm64', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=arm64'], + }, + { + 'name': 'sysroot_x86', + 'pattern': '.', + 'condition': 'checkout_linux and (checkout_x86 or checkout_x64)', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=x86'], + }, + { + 'name': 'sysroot_mips', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_mips', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=mips'], + }, + { + 'name': 'sysroot_x64', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_x64', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=x64'], + }, + { + # Update the Windows toolchain if necessary. + 'name': 'win_toolchain', + 'pattern': '.', + 'condition': 'checkout_win', + 'action': ['python', 'SPIRV-Tools/build/vs_toolchain.py', 'update', '--force'], + }, + { + # Update the Mac toolchain if necessary. + 'name': 'mac_toolchain', + 'pattern': '.', + 'action': ['python', 'SPIRV-Tools/build/mac_toolchain.py'], + }, +] diff --git a/3rdparty/spirv-tools/PRESUBMIT.py b/3rdparty/spirv-tools/PRESUBMIT.py new file mode 100644 index 000000000..dd3117f22 --- /dev/null +++ b/3rdparty/spirv-tools/PRESUBMIT.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018 The Khronos Group Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Presubmit script for SPIRV-Tools. + +See http://dev.chromium.org/developers/how-tos/depottools/presubmit-scripts +for more details about the presubmit API built into depot_tools. +""" + +LINT_FILTERS = [ + "-build/storage_class", + "-readability/casting", + "-readability/fn_size", + "-readability/todo", + "-runtime/explicit", + "-runtime/int", + "-runtime/printf", + "-runtime/references", + "-runtime/string", +] + + +def CheckChangeOnUpload(input_api, output_api): + results = [] + results += input_api.canned_checks.CheckPatchFormatted(input_api, output_api) + results += input_api.canned_checks.CheckChangeLintsClean( + input_api, output_api, None, LINT_FILTERS) + + return results diff --git a/3rdparty/spirv-tools/README.md b/3rdparty/spirv-tools/README.md index 1c2830de3..a5c3e37ac 100644 --- a/3rdparty/spirv-tools/README.md +++ b/3rdparty/spirv-tools/README.md @@ -1,7 +1,9 @@ # SPIR-V Tools -[![Build Status](https://travis-ci.org/KhronosGroup/SPIRV-Tools.svg?branch=master)](https://travis-ci.org/KhronosGroup/SPIRV-Tools) [![Build status](https://ci.appveyor.com/api/projects/status/gpue87cesrx3pi0d/branch/master?svg=true)](https://ci.appveyor.com/project/Khronoswebmaster/spirv-tools/branch/master) +Linux![Linux Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_linux_release.svg) +MacOS![MacOS Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_macos_release.svg) +Windows![Windows Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_windows_release.svg) ## Overview @@ -274,6 +276,26 @@ via setting `SPIRV_TOOLS_EXTRA_DEFINITIONS`. For example, by setting it to `/D_ITERATOR_DEBUG_LEVEL=0` on Windows, you can disable checked iterators and iterator debugging. +### Android + +SPIR-V Tools supports building static libraries `libSPIRV-Tools.a` and +`libSPIRV-Tools-opt.a` for Android: + +``` +cd + +export ANDROID_NDK=/path/to/your/ndk + +mkdir build && cd build +mkdir libs +mkdir app + +$ANDROID_NDK/ndk-build -C ../android_test \ + NDK_PROJECT_PATH=. \ + NDK_LIBS_OUT=`pwd`/libs \ + NDK_APP_OUT=`pwd`/app +``` + ## Library ### Usage diff --git a/3rdparty/spirv-tools/build_overrides/build.gni b/3rdparty/spirv-tools/build_overrides/build.gni new file mode 100644 index 000000000..833fcd349 --- /dev/null +++ b/3rdparty/spirv-tools/build_overrides/build.gni @@ -0,0 +1,46 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Variable that can be used to support multiple build scenarios, like having +# Chromium specific targets in a client project's GN file etc. +build_with_chromium = false + +# Don't use Chromium's third_party/binutils. +linux_use_bundled_binutils_override = false + +declare_args() { + # Android 32-bit non-component, non-clang builds cannot have symbol_level=2 + # due to 4GiB file size limit, see https://crbug.com/648948. + # Set this flag to true to skip the assertion. + ignore_elf32_limitations = false + + # Use the system install of Xcode for tools like ibtool, libtool, etc. + # This does not affect the compiler. When this variable is false, targets will + # instead use a hermetic install of Xcode. [The hermetic install can be + # obtained with gclient sync after setting the environment variable + # FORCE_MAC_TOOLCHAIN]. + use_system_xcode = "" +} + +if (use_system_xcode == "") { + if (target_os == "mac") { + _result = exec_script("//build/mac/should_use_hermetic_xcode.py", + [ target_os ], + "value") + use_system_xcode = _result == 0 + } + if (target_os == "ios") { + use_system_xcode = true + } +} diff --git a/3rdparty/spirv-tools/build_overrides/gtest.gni b/3rdparty/spirv-tools/build_overrides/gtest.gni new file mode 100644 index 000000000..c8b1bae4c --- /dev/null +++ b/3rdparty/spirv-tools/build_overrides/gtest.gni @@ -0,0 +1,25 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Exclude support for registering main function in multi-process tests. +gtest_include_multiprocess = false + +# Exclude support for platform-specific operations across unit tests. +gtest_include_platform_test = false + +# Exclude support for testing Objective C code on OS X and iOS. +gtest_include_objc_support = false + +# Exclude support for flushing coverage files on iOS. +gtest_include_ios_coverage = false diff --git a/3rdparty/spirv-tools/build_overrides/spirv_tools.gni b/3rdparty/spirv-tools/build_overrides/spirv_tools.gni new file mode 100644 index 000000000..24aa033d7 --- /dev/null +++ b/3rdparty/spirv-tools/build_overrides/spirv_tools.gni @@ -0,0 +1,25 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These are variables that are overridable by projects that include +# SPIRV-Tools. The values in this file are the defaults for when we are +# building from SPIRV-Tools' repository. + +# Whether we are building from SPIRV-Tools' repository. +# MUST be set to false in other projects. +spirv_tools_standalone = true + +# The path to SPIRV-Tools' dependencies +spirv_tools_googletest_dir = "//external/googletest" +spirv_tools_spirv_headers_dir = "//external/spirv-headers" diff --git a/3rdparty/spirv-tools/cmake/setup_build.cmake b/3rdparty/spirv-tools/cmake/setup_build.cmake new file mode 100644 index 000000000..6ba4c53d7 --- /dev/null +++ b/3rdparty/spirv-tools/cmake/setup_build.cmake @@ -0,0 +1,20 @@ +# Find nosetests; see spirv_add_nosetests() for opting in to nosetests in a +# specific directory. +find_program(NOSETESTS_EXE NAMES nosetests PATHS $ENV{PYTHON_PACKAGE_PATH}) +if (NOT NOSETESTS_EXE) + message(STATUS "SPIRV-Tools: nosetests was not found - python support code will not be tested") +else() + message(STATUS "SPIRV-Tools: nosetests found - python support code will be tested") +endif() + +# Run nosetests on file ${PREFIX}_nosetest.py. Nosetests will look for classes +# and functions whose names start with "nosetest". The test name will be +# ${PREFIX}_nosetests. +function(spirv_add_nosetests PREFIX) + if(NOT "${SPIRV_SKIP_TESTS}" AND NOSETESTS_EXE) + add_test( + NAME ${PREFIX}_nosetests + COMMAND ${NOSETESTS_EXE} -m "^[Nn]ose[Tt]est" -v + ${CMAKE_CURRENT_SOURCE_DIR}/${PREFIX}_nosetest.py) + endif() +endfunction() diff --git a/3rdparty/spirv-tools/codereview.settings b/3rdparty/spirv-tools/codereview.settings new file mode 100644 index 000000000..ef84cf857 --- /dev/null +++ b/3rdparty/spirv-tools/codereview.settings @@ -0,0 +1,2 @@ +# This file is used by git cl to get repository specific information. +CODE_REVIEW_SERVER: github.com diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/CMakeLists.txt b/3rdparty/spirv-tools/external/SPIRV-Headers/CMakeLists.txt index a5bff172a..2488baf0a 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/CMakeLists.txt +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/CMakeLists.txt @@ -28,7 +28,7 @@ # The SPIR-V headers from the SPIR-V Registry # https://www.khronos.org/registry/spir-v/ # -cmake_minimum_required(VERSION 2.8) +cmake_minimum_required(VERSION 2.8.11) project(SPIRV-Headers) # There are two ways to use this project. @@ -42,10 +42,19 @@ project(SPIRV-Headers) # directory. To install the headers: # 1. mkdir build ; cd build # 2. cmake .. -# 3. cmake --build . install-headers +# 3. cmake --build . --target install -file(GLOB_RECURSE FILES include/spirv/*) +file(GLOB_RECURSE HEADER_FILES + RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + include/spirv/*) +foreach(HEADER_FILE ${HEADER_FILES}) + get_filename_component(HEADER_INSTALL_DIR ${HEADER_FILE} PATH) + install(FILES ${HEADER_FILE} DESTINATION ${HEADER_INSTALL_DIR}) +endforeach() + +# legacy add_custom_target(install-headers - COMMAND cmake -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv ${CMAKE_INSTALL_PREFIX}/include/spirv) + COMMAND cmake -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv + $ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/include/spirv) add_subdirectory(example) diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/README.md b/3rdparty/spirv-tools/external/SPIRV-Headers/README.md index 4c3d5d49e..846b20d80 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/README.md +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/README.md @@ -35,10 +35,7 @@ Pull requests can be made to mkdir build cd build cmake .. -# Linux -cmake --build . --target install-headers -# Windows -cmake --build . --config Debug --target install-headers +cmake --build . --target install ``` Then, for example, you will have `/usr/local/include/spirv/unified1/spirv.h` diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/spir-v.xml b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/spir-v.xml index 017615dfa..b05bfa7c4 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/spir-v.xml +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/spir-v.xml @@ -68,7 +68,8 @@ - + + diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/OpenCL.std.h b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/OpenCL.std.h index 19a668849..fe759e1bc 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/OpenCL.std.h +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/OpenCL.std.h @@ -24,6 +24,9 @@ ** IN THE MATERIALS. */ +#ifndef OPENCLstd_H +#define OPENCLstd_H + namespace OpenCLLIB { enum Entrypoints { @@ -208,3 +211,5 @@ enum Entrypoints { }; } // end namespace OpenCLLIB + +#endif // #ifndef OPENCLstd_H diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.core.grammar.json b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.core.grammar.json index a03c02433..cb641420d 100755 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.core.grammar.json +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.core.grammar.json @@ -3914,7 +3914,7 @@ { "kind" : "IdRef", "name" : "'Target'" }, { "kind" : "Decoration" } ], - "extensions" : [ "SPV_GOOGLE_decorate_string" ], + "extensions" : [ "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1" ], "version" : "None" }, { @@ -3925,7 +3925,7 @@ { "kind" : "LiteralInteger", "name" : "'Member'" }, { "kind" : "Decoration" } ], - "extensions" : [ "SPV_GOOGLE_decorate_string" ], + "extensions" : [ "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1" ], "version" : "None" }, { @@ -3991,6 +3991,7 @@ { "enumerant" : "ConstOffsets", "value" : "0x0020", + "capabilities" : [ "ImageGatherExtended" ], "parameters" : [ { "kind" : "IdRef" } ] @@ -5550,12 +5551,14 @@ "enumerant" : "OverrideCoverageNV", "value" : 5248, "capabilities" : [ "SampleMaskOverrideCoverageNV" ], + "extensions" : [ "SPV_NV_sample_mask_override_coverage" ], "version" : "None" }, { "enumerant" : "PassthroughNV", "value" : 5250, "capabilities" : [ "GeometryShaderPassthroughNV" ], + "extensions" : [ "SPV_NV_geometry_shader_passthrough" ], "version" : "None" }, { @@ -5568,6 +5571,7 @@ "enumerant" : "SecondaryViewportRelativeNV", "value" : 5256, "capabilities" : [ "ShaderStereoViewNV" ], + "extensions" : [ "SPV_NV_stereo_view_rendering" ], "version" : "None", "parameters" : [ { "kind" : "LiteralInteger", "name" : "'Offset'" } @@ -5960,12 +5964,14 @@ "enumerant" : "SecondaryPositionNV", "value" : 5257, "capabilities" : [ "ShaderStereoViewNV" ], + "extensions" : [ "SPV_NV_stereo_view_rendering" ], "version" : "None" }, { "enumerant" : "SecondaryViewportMaskNV", "value" : 5258, "capabilities" : [ "ShaderStereoViewNV" ], + "extensions" : [ "SPV_NV_stereo_view_rendering" ], "version" : "None" }, { @@ -6043,17 +6049,23 @@ { "enumerant" : "PartitionedReduceNV", "value" : 6, - "capabilities" : [ "GroupNonUniformPartitionedNV" ] + "capabilities" : [ "GroupNonUniformPartitionedNV" ], + "extensions" : [ "SPV_NV_shader_subgroup_partitioned" ], + "version" : "None" }, { "enumerant" : "PartitionedInclusiveScanNV", "value" : 7, - "capabilities" : [ "GroupNonUniformPartitionedNV" ] + "capabilities" : [ "GroupNonUniformPartitionedNV" ], + "extensions" : [ "SPV_NV_shader_subgroup_partitioned" ], + "version" : "None" }, { "enumerant" : "PartitionedExclusiveScanNV", "value" : 8, - "capabilities" : [ "GroupNonUniformPartitionedNV" ] + "capabilities" : [ "GroupNonUniformPartitionedNV" ], + "extensions" : [ "SPV_NV_shader_subgroup_partitioned" ], + "version" : "None" } ] }, @@ -6260,8 +6272,7 @@ }, { "enumerant" : "Int8", - "value" : 39, - "capabilities" : [ "Kernel" ] + "value" : 39 }, { "enumerant" : "InputAttachment", @@ -6518,6 +6529,25 @@ "extensions" : [ "SPV_KHR_post_depth_coverage" ], "version" : "None" }, + { + "enumerant" : "StorageBuffer8BitAccess", + "value" : 4448, + "extensions" : [ "SPV_KHR_8bit_storage" ], + "version" : "None" + }, + { + "enumerant" : "UniformAndStorageBuffer8BitAccess", + "value" : 4449, + "capabilities" : [ "StorageBuffer8BitAccess" ], + "extensions" : [ "SPV_KHR_8bit_storage" ], + "version" : "None" + }, + { + "enumerant" : "StoragePushConstant8", + "value" : 4450, + "extensions" : [ "SPV_KHR_8bit_storage" ], + "version" : "None" + }, { "enumerant" : "Float16ImageAMD", "value" : 5008, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.h b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.h index e0a0330ba..4c90c936c 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.h +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.h @@ -683,6 +683,9 @@ typedef enum SpvCapability_ { SpvCapabilityVariablePointers = 4442, SpvCapabilityAtomicStorageOps = 4445, SpvCapabilitySampleMaskPostDepthCoverage = 4447, + SpvCapabilityStorageBuffer8BitAccess = 4448, + SpvCapabilityUniformAndStorageBuffer8BitAccess = 4449, + SpvCapabilityStoragePushConstant8 = 4450, SpvCapabilityFloat16ImageAMD = 5008, SpvCapabilityImageGatherBiasLodAMD = 5009, SpvCapabilityFragmentMaskAMD = 5010, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp index e21762dbe..f16c2963e 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp @@ -679,6 +679,9 @@ enum Capability { CapabilityVariablePointers = 4442, CapabilityAtomicStorageOps = 4445, CapabilitySampleMaskPostDepthCoverage = 4447, + CapabilityStorageBuffer8BitAccess = 4448, + CapabilityUniformAndStorageBuffer8BitAccess = 4449, + CapabilityStoragePushConstant8 = 4450, CapabilityFloat16ImageAMD = 5008, CapabilityImageGatherBiasLodAMD = 5009, CapabilityFragmentMaskAMD = 5010, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp11 b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp11 index 4956a4916..3bd5b8a0d 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp11 +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.hpp11 @@ -679,6 +679,9 @@ enum class Capability : unsigned { VariablePointers = 4442, AtomicStorageOps = 4445, SampleMaskPostDepthCoverage = 4447, + StorageBuffer8BitAccess = 4448, + UniformAndStorageBuffer8BitAccess = 4449, + StoragePushConstant8 = 4450, Float16ImageAMD = 5008, ImageGatherBiasLodAMD = 5009, FragmentMaskAMD = 5010, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.json b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.json index 5c3480e24..a592dfa23 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.json +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.json @@ -704,6 +704,9 @@ "VariablePointers": 4442, "AtomicStorageOps": 4445, "SampleMaskPostDepthCoverage": 4447, + "StorageBuffer8BitAccess": 4448, + "UniformAndStorageBuffer8BitAccess": 4449, + "StoragePushConstant8": 4450, "Float16ImageAMD": 5008, "ImageGatherBiasLodAMD": 5009, "FragmentMaskAMD": 5010, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.lua b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.lua index 8a3b496da..43e9ba5be 100644 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.lua +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.lua @@ -641,6 +641,9 @@ spv = { VariablePointers = 4442, AtomicStorageOps = 4445, SampleMaskPostDepthCoverage = 4447, + StorageBuffer8BitAccess = 4448, + UniformAndStorageBuffer8BitAccess = 4449, + StoragePushConstant8 = 4450, Float16ImageAMD = 5008, ImageGatherBiasLodAMD = 5009, FragmentMaskAMD = 5010, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.py b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.py index f3d698ab1..cb3775ff9 100755 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.py +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/include/spirv/unified1/spirv.py @@ -641,6 +641,9 @@ spv = { 'VariablePointers' : 4442, 'AtomicStorageOps' : 4445, 'SampleMaskPostDepthCoverage' : 4447, + 'StorageBuffer8BitAccess' : 4448, + 'UniformAndStorageBuffer8BitAccess' : 4449, + 'StoragePushConstant8' : 4450, 'Float16ImageAMD' : 5008, 'ImageGatherBiasLodAMD' : 5009, 'FragmentMaskAMD' : 5010, diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/bin/makeHeaders b/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/bin/makeHeaders index d022fa1eb..bf2c61515 100755 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/bin/makeHeaders +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/bin/makeHeaders @@ -1,4 +1,4 @@ -#!/usr/bin/bash +#!/usr/bin/env bash cd ../../include/spirv/unified1 ../../../tools/buildHeaders/build/install/bin/buildSpvHeaders -H spirv.core.grammar.json diff --git a/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/jsonToSpirv.h b/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/jsonToSpirv.h index b25b89ed2..00a2f70d9 100755 --- a/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/jsonToSpirv.h +++ b/3rdparty/spirv-tools/external/SPIRV-Headers/tools/buildHeaders/jsonToSpirv.h @@ -26,6 +26,7 @@ #ifndef JSON_TO_SPIRV #define JSON_TO_SPIRV +#include #include #include #include diff --git a/3rdparty/spirv-tools/include/generated/build-version.inc b/3rdparty/spirv-tools/include/generated/build-version.inc index f6329caa0..311820b73 100644 --- a/3rdparty/spirv-tools/include/generated/build-version.inc +++ b/3rdparty/spirv-tools/include/generated/build-version.inc @@ -1 +1 @@ -"v2018.3-dev", "SPIRV-Tools v2018.3-dev v2018.2-56-g3020104" +"v2018.5-dev", "SPIRV-Tools v2018.5-dev v2018.4-149-g58e53ea" diff --git a/3rdparty/spirv-tools/include/generated/core.insts-unified1.inc b/3rdparty/spirv-tools/include/generated/core.insts-unified1.inc index 1d686309e..e7e16db6b 100644 --- a/3rdparty/spirv-tools/include/generated/core.insts-unified1.inc +++ b/3rdparty/spirv-tools/include/generated/core.insts-unified1.inc @@ -31,13 +31,13 @@ static const SpvCapability pygen_variable_caps_SubgroupImageBlockIOINTEL[] = {Sp static const SpvCapability pygen_variable_caps_SubgroupShuffleINTEL[] = {SpvCapabilitySubgroupShuffleINTEL}; static const SpvCapability pygen_variable_caps_SubgroupVoteKHR[] = {SpvCapabilitySubgroupVoteKHR}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_shader_ballot[] = {libspirv::Extension::kSPV_AMD_shader_ballot}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_shader_fragment_mask[] = {libspirv::Extension::kSPV_AMD_shader_fragment_mask}; -static const libspirv::Extension pygen_variable_exts_SPV_GOOGLE_decorate_string[] = {libspirv::Extension::kSPV_GOOGLE_decorate_string}; -static const libspirv::Extension pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1[] = {libspirv::Extension::kSPV_GOOGLE_hlsl_functionality1}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_shader_ballot[] = {libspirv::Extension::kSPV_KHR_shader_ballot}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_subgroup_vote[] = {libspirv::Extension::kSPV_KHR_subgroup_vote}; -static const libspirv::Extension pygen_variable_exts_SPV_NV_shader_subgroup_partitioned[] = {libspirv::Extension::kSPV_NV_shader_subgroup_partitioned}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_ballot[] = {spvtools::Extension::kSPV_AMD_shader_ballot}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_fragment_mask[] = {spvtools::Extension::kSPV_AMD_shader_fragment_mask}; +static const spvtools::Extension pygen_variable_exts_SPV_GOOGLE_decorate_stringSPV_GOOGLE_hlsl_functionality1[] = {spvtools::Extension::kSPV_GOOGLE_decorate_string, spvtools::Extension::kSPV_GOOGLE_hlsl_functionality1}; +static const spvtools::Extension pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1[] = {spvtools::Extension::kSPV_GOOGLE_hlsl_functionality1}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_shader_ballot[] = {spvtools::Extension::kSPV_KHR_shader_ballot}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_subgroup_vote[] = {spvtools::Extension::kSPV_KHR_subgroup_vote}; +static const spvtools::Extension pygen_variable_exts_SPV_NV_shader_subgroup_partitioned[] = {spvtools::Extension::kSPV_NV_shader_subgroup_partitioned}; static const spv_opcode_desc_t kOpcodeTableEntries[] = { {"Nop", SpvOpNop, 0, nullptr, 0, {}, 0, 0, 0, nullptr, SPV_SPIRV_VERSION_WORD(1, 0)}, @@ -405,6 +405,6 @@ static const spv_opcode_desc_t kOpcodeTableEntries[] = { {"SubgroupBlockWriteINTEL", SpvOpSubgroupBlockWriteINTEL, 1, pygen_variable_caps_SubgroupBufferBlockIOINTEL, 2, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 0, 0, 0, nullptr, 0xffffffffu}, {"SubgroupImageBlockReadINTEL", SpvOpSubgroupImageBlockReadINTEL, 1, pygen_variable_caps_SubgroupImageBlockIOINTEL, 4, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 0, nullptr, 0xffffffffu}, {"SubgroupImageBlockWriteINTEL", SpvOpSubgroupImageBlockWriteINTEL, 1, pygen_variable_caps_SubgroupImageBlockIOINTEL, 3, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 0, 0, 0, nullptr, 0xffffffffu}, - {"DecorateStringGOOGLE", SpvOpDecorateStringGOOGLE, 0, nullptr, 2, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_DECORATION}, 0, 0, 1, pygen_variable_exts_SPV_GOOGLE_decorate_string, 0xffffffffu}, - {"MemberDecorateStringGOOGLE", SpvOpMemberDecorateStringGOOGLE, 0, nullptr, 3, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_LITERAL_INTEGER, SPV_OPERAND_TYPE_DECORATION}, 0, 0, 1, pygen_variable_exts_SPV_GOOGLE_decorate_string, 0xffffffffu} + {"DecorateStringGOOGLE", SpvOpDecorateStringGOOGLE, 0, nullptr, 2, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_DECORATION}, 0, 0, 2, pygen_variable_exts_SPV_GOOGLE_decorate_stringSPV_GOOGLE_hlsl_functionality1, 0xffffffffu}, + {"MemberDecorateStringGOOGLE", SpvOpMemberDecorateStringGOOGLE, 0, nullptr, 3, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_LITERAL_INTEGER, SPV_OPERAND_TYPE_DECORATION}, 0, 0, 2, pygen_variable_exts_SPV_GOOGLE_decorate_stringSPV_GOOGLE_hlsl_functionality1, 0xffffffffu} }; diff --git a/3rdparty/spirv-tools/include/generated/enum_string_mapping.inc b/3rdparty/spirv-tools/include/generated/enum_string_mapping.inc index a5b6274b2..964b142d9 100644 --- a/3rdparty/spirv-tools/include/generated/enum_string_mapping.inc +++ b/3rdparty/spirv-tools/include/generated/enum_string_mapping.inc @@ -36,6 +36,8 @@ const char* ExtensionToString(Extension extension) { return "SPV_INTEL_subgroups"; case Extension::kSPV_KHR_16bit_storage: return "SPV_KHR_16bit_storage"; + case Extension::kSPV_KHR_8bit_storage: + return "SPV_KHR_8bit_storage"; case Extension::kSPV_KHR_device_group: return "SPV_KHR_device_group"; case Extension::kSPV_KHR_multiview: @@ -75,8 +77,8 @@ const char* ExtensionToString(Extension extension) { bool GetExtensionFromString(const char* str, Extension* extension) { - static const char* known_ext_strs[] = { "SPV_AMD_gcn_shader", "SPV_AMD_gpu_shader_half_float", "SPV_AMD_gpu_shader_half_float_fetch", "SPV_AMD_gpu_shader_int16", "SPV_AMD_shader_ballot", "SPV_AMD_shader_explicit_vertex_parameter", "SPV_AMD_shader_fragment_mask", "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_trinary_minmax", "SPV_AMD_texture_gather_bias_lod", "SPV_EXT_descriptor_indexing", "SPV_EXT_fragment_fully_covered", "SPV_EXT_shader_stencil_export", "SPV_EXT_shader_viewport_index_layer", "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1", "SPV_INTEL_subgroups", "SPV_KHR_16bit_storage", "SPV_KHR_device_group", "SPV_KHR_multiview", "SPV_KHR_post_depth_coverage", "SPV_KHR_shader_atomic_counter_ops", "SPV_KHR_shader_ballot", "SPV_KHR_shader_draw_parameters", "SPV_KHR_storage_buffer_storage_class", "SPV_KHR_subgroup_vote", "SPV_KHR_variable_pointers", "SPV_NVX_multiview_per_view_attributes", "SPV_NV_geometry_shader_passthrough", "SPV_NV_sample_mask_override_coverage", "SPV_NV_shader_subgroup_partitioned", "SPV_NV_stereo_view_rendering", "SPV_NV_viewport_array2", "SPV_VALIDATOR_ignore_type_decl_unique" }; - static const Extension known_ext_ids[] = { Extension::kSPV_AMD_gcn_shader, Extension::kSPV_AMD_gpu_shader_half_float, Extension::kSPV_AMD_gpu_shader_half_float_fetch, Extension::kSPV_AMD_gpu_shader_int16, Extension::kSPV_AMD_shader_ballot, Extension::kSPV_AMD_shader_explicit_vertex_parameter, Extension::kSPV_AMD_shader_fragment_mask, Extension::kSPV_AMD_shader_image_load_store_lod, Extension::kSPV_AMD_shader_trinary_minmax, Extension::kSPV_AMD_texture_gather_bias_lod, Extension::kSPV_EXT_descriptor_indexing, Extension::kSPV_EXT_fragment_fully_covered, Extension::kSPV_EXT_shader_stencil_export, Extension::kSPV_EXT_shader_viewport_index_layer, Extension::kSPV_GOOGLE_decorate_string, Extension::kSPV_GOOGLE_hlsl_functionality1, Extension::kSPV_INTEL_subgroups, Extension::kSPV_KHR_16bit_storage, Extension::kSPV_KHR_device_group, Extension::kSPV_KHR_multiview, Extension::kSPV_KHR_post_depth_coverage, Extension::kSPV_KHR_shader_atomic_counter_ops, Extension::kSPV_KHR_shader_ballot, Extension::kSPV_KHR_shader_draw_parameters, Extension::kSPV_KHR_storage_buffer_storage_class, Extension::kSPV_KHR_subgroup_vote, Extension::kSPV_KHR_variable_pointers, Extension::kSPV_NVX_multiview_per_view_attributes, Extension::kSPV_NV_geometry_shader_passthrough, Extension::kSPV_NV_sample_mask_override_coverage, Extension::kSPV_NV_shader_subgroup_partitioned, Extension::kSPV_NV_stereo_view_rendering, Extension::kSPV_NV_viewport_array2, Extension::kSPV_VALIDATOR_ignore_type_decl_unique }; + static const char* known_ext_strs[] = { "SPV_AMD_gcn_shader", "SPV_AMD_gpu_shader_half_float", "SPV_AMD_gpu_shader_half_float_fetch", "SPV_AMD_gpu_shader_int16", "SPV_AMD_shader_ballot", "SPV_AMD_shader_explicit_vertex_parameter", "SPV_AMD_shader_fragment_mask", "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_trinary_minmax", "SPV_AMD_texture_gather_bias_lod", "SPV_EXT_descriptor_indexing", "SPV_EXT_fragment_fully_covered", "SPV_EXT_shader_stencil_export", "SPV_EXT_shader_viewport_index_layer", "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1", "SPV_INTEL_subgroups", "SPV_KHR_16bit_storage", "SPV_KHR_8bit_storage", "SPV_KHR_device_group", "SPV_KHR_multiview", "SPV_KHR_post_depth_coverage", "SPV_KHR_shader_atomic_counter_ops", "SPV_KHR_shader_ballot", "SPV_KHR_shader_draw_parameters", "SPV_KHR_storage_buffer_storage_class", "SPV_KHR_subgroup_vote", "SPV_KHR_variable_pointers", "SPV_NVX_multiview_per_view_attributes", "SPV_NV_geometry_shader_passthrough", "SPV_NV_sample_mask_override_coverage", "SPV_NV_shader_subgroup_partitioned", "SPV_NV_stereo_view_rendering", "SPV_NV_viewport_array2", "SPV_VALIDATOR_ignore_type_decl_unique" }; + static const Extension known_ext_ids[] = { Extension::kSPV_AMD_gcn_shader, Extension::kSPV_AMD_gpu_shader_half_float, Extension::kSPV_AMD_gpu_shader_half_float_fetch, Extension::kSPV_AMD_gpu_shader_int16, Extension::kSPV_AMD_shader_ballot, Extension::kSPV_AMD_shader_explicit_vertex_parameter, Extension::kSPV_AMD_shader_fragment_mask, Extension::kSPV_AMD_shader_image_load_store_lod, Extension::kSPV_AMD_shader_trinary_minmax, Extension::kSPV_AMD_texture_gather_bias_lod, Extension::kSPV_EXT_descriptor_indexing, Extension::kSPV_EXT_fragment_fully_covered, Extension::kSPV_EXT_shader_stencil_export, Extension::kSPV_EXT_shader_viewport_index_layer, Extension::kSPV_GOOGLE_decorate_string, Extension::kSPV_GOOGLE_hlsl_functionality1, Extension::kSPV_INTEL_subgroups, Extension::kSPV_KHR_16bit_storage, Extension::kSPV_KHR_8bit_storage, Extension::kSPV_KHR_device_group, Extension::kSPV_KHR_multiview, Extension::kSPV_KHR_post_depth_coverage, Extension::kSPV_KHR_shader_atomic_counter_ops, Extension::kSPV_KHR_shader_ballot, Extension::kSPV_KHR_shader_draw_parameters, Extension::kSPV_KHR_storage_buffer_storage_class, Extension::kSPV_KHR_subgroup_vote, Extension::kSPV_KHR_variable_pointers, Extension::kSPV_NVX_multiview_per_view_attributes, Extension::kSPV_NV_geometry_shader_passthrough, Extension::kSPV_NV_sample_mask_override_coverage, Extension::kSPV_NV_shader_subgroup_partitioned, Extension::kSPV_NV_stereo_view_rendering, Extension::kSPV_NV_viewport_array2, Extension::kSPV_VALIDATOR_ignore_type_decl_unique }; const auto b = std::begin(known_ext_strs); const auto e = std::end(known_ext_strs); const auto found = std::equal_range( @@ -252,6 +254,12 @@ const char* CapabilityToString(SpvCapability capability) { return "AtomicStorageOps"; case SpvCapabilitySampleMaskPostDepthCoverage: return "SampleMaskPostDepthCoverage"; + case SpvCapabilityStorageBuffer8BitAccess: + return "StorageBuffer8BitAccess"; + case SpvCapabilityUniformAndStorageBuffer8BitAccess: + return "UniformAndStorageBuffer8BitAccess"; + case SpvCapabilityStoragePushConstant8: + return "StoragePushConstant8"; case SpvCapabilityFloat16ImageAMD: return "Float16ImageAMD"; case SpvCapabilityImageGatherBiasLodAMD: diff --git a/3rdparty/spirv-tools/include/generated/extension_enum.inc b/3rdparty/spirv-tools/include/generated/extension_enum.inc index db2809338..c1cf94a7b 100644 --- a/3rdparty/spirv-tools/include/generated/extension_enum.inc +++ b/3rdparty/spirv-tools/include/generated/extension_enum.inc @@ -16,6 +16,7 @@ kSPV_GOOGLE_decorate_string, kSPV_GOOGLE_hlsl_functionality1, kSPV_INTEL_subgroups, kSPV_KHR_16bit_storage, +kSPV_KHR_8bit_storage, kSPV_KHR_device_group, kSPV_KHR_multiview, kSPV_KHR_post_depth_coverage, diff --git a/3rdparty/spirv-tools/include/generated/generators.inc b/3rdparty/spirv-tools/include/generated/generators.inc index c186d0cf0..39709d3aa 100644 --- a/3rdparty/spirv-tools/include/generated/generators.inc +++ b/3rdparty/spirv-tools/include/generated/generators.inc @@ -17,3 +17,4 @@ {16, "X-LEGEND", "Mesa-IR/SPIR-V Translator", "X-LEGEND Mesa-IR/SPIR-V Translator"}, {17, "Khronos", "SPIR-V Tools Linker", "Khronos SPIR-V Tools Linker"}, {18, "Wine", "VKD3D Shader Compiler", "Wine VKD3D Shader Compiler"}, +{19, "Clay", "Clay Shader Compiler", "Clay Clay Shader Compiler"}, diff --git a/3rdparty/spirv-tools/include/generated/operand.kinds-unified1.inc b/3rdparty/spirv-tools/include/generated/operand.kinds-unified1.inc index 991cfa837..e40c74b21 100644 --- a/3rdparty/spirv-tools/include/generated/operand.kinds-unified1.inc +++ b/3rdparty/spirv-tools/include/generated/operand.kinds-unified1.inc @@ -48,6 +48,7 @@ static const SpvCapability pygen_variable_caps_ShaderViewportIndexLayerNV[] = {S static const SpvCapability pygen_variable_caps_ShaderViewportMaskNV[] = {SpvCapabilityShaderViewportMaskNV}; static const SpvCapability pygen_variable_caps_StencilExportEXT[] = {SpvCapabilityStencilExportEXT}; static const SpvCapability pygen_variable_caps_StorageBuffer16BitAccessStorageUniformBufferBlock16[] = {SpvCapabilityStorageBuffer16BitAccess, SpvCapabilityStorageUniformBufferBlock16}; +static const SpvCapability pygen_variable_caps_StorageBuffer8BitAccess[] = {SpvCapabilityStorageBuffer8BitAccess}; static const SpvCapability pygen_variable_caps_StorageImageExtendedFormats[] = {SpvCapabilityStorageImageExtendedFormats}; static const SpvCapability pygen_variable_caps_SubgroupBallotKHRGroupNonUniformBallot[] = {SpvCapabilitySubgroupBallotKHR, SpvCapabilityGroupNonUniformBallot}; static const SpvCapability pygen_variable_caps_SubgroupDispatch[] = {SpvCapabilitySubgroupDispatch}; @@ -55,33 +56,34 @@ static const SpvCapability pygen_variable_caps_Tessellation[] = {SpvCapabilityTe static const SpvCapability pygen_variable_caps_TransformFeedback[] = {SpvCapabilityTransformFeedback}; static const SpvCapability pygen_variable_caps_VariablePointersStorageBuffer[] = {SpvCapabilityVariablePointersStorageBuffer}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_gpu_shader_half_float_fetch[] = {libspirv::Extension::kSPV_AMD_gpu_shader_half_float_fetch}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_shader_explicit_vertex_parameter[] = {libspirv::Extension::kSPV_AMD_shader_explicit_vertex_parameter}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_shader_fragment_mask[] = {libspirv::Extension::kSPV_AMD_shader_fragment_mask}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_shader_image_load_store_lod[] = {libspirv::Extension::kSPV_AMD_shader_image_load_store_lod}; -static const libspirv::Extension pygen_variable_exts_SPV_AMD_texture_gather_bias_lod[] = {libspirv::Extension::kSPV_AMD_texture_gather_bias_lod}; -static const libspirv::Extension pygen_variable_exts_SPV_EXT_descriptor_indexing[] = {libspirv::Extension::kSPV_EXT_descriptor_indexing}; -static const libspirv::Extension pygen_variable_exts_SPV_EXT_fragment_fully_covered[] = {libspirv::Extension::kSPV_EXT_fragment_fully_covered}; -static const libspirv::Extension pygen_variable_exts_SPV_EXT_shader_stencil_export[] = {libspirv::Extension::kSPV_EXT_shader_stencil_export}; -static const libspirv::Extension pygen_variable_exts_SPV_EXT_shader_viewport_index_layer[] = {libspirv::Extension::kSPV_EXT_shader_viewport_index_layer}; -static const libspirv::Extension pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1[] = {libspirv::Extension::kSPV_GOOGLE_hlsl_functionality1}; -static const libspirv::Extension pygen_variable_exts_SPV_INTEL_subgroups[] = {libspirv::Extension::kSPV_INTEL_subgroups}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_16bit_storage[] = {libspirv::Extension::kSPV_KHR_16bit_storage}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_device_group[] = {libspirv::Extension::kSPV_KHR_device_group}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_multiview[] = {libspirv::Extension::kSPV_KHR_multiview}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_post_depth_coverage[] = {libspirv::Extension::kSPV_KHR_post_depth_coverage}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_shader_atomic_counter_ops[] = {libspirv::Extension::kSPV_KHR_shader_atomic_counter_ops}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_shader_ballot[] = {libspirv::Extension::kSPV_KHR_shader_ballot}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_shader_draw_parameters[] = {libspirv::Extension::kSPV_KHR_shader_draw_parameters}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_storage_buffer_storage_classSPV_KHR_variable_pointers[] = {libspirv::Extension::kSPV_KHR_storage_buffer_storage_class, libspirv::Extension::kSPV_KHR_variable_pointers}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_subgroup_vote[] = {libspirv::Extension::kSPV_KHR_subgroup_vote}; -static const libspirv::Extension pygen_variable_exts_SPV_KHR_variable_pointers[] = {libspirv::Extension::kSPV_KHR_variable_pointers}; -static const libspirv::Extension pygen_variable_exts_SPV_NVX_multiview_per_view_attributes[] = {libspirv::Extension::kSPV_NVX_multiview_per_view_attributes}; -static const libspirv::Extension pygen_variable_exts_SPV_NV_geometry_shader_passthrough[] = {libspirv::Extension::kSPV_NV_geometry_shader_passthrough}; -static const libspirv::Extension pygen_variable_exts_SPV_NV_sample_mask_override_coverage[] = {libspirv::Extension::kSPV_NV_sample_mask_override_coverage}; -static const libspirv::Extension pygen_variable_exts_SPV_NV_shader_subgroup_partitioned[] = {libspirv::Extension::kSPV_NV_shader_subgroup_partitioned}; -static const libspirv::Extension pygen_variable_exts_SPV_NV_stereo_view_rendering[] = {libspirv::Extension::kSPV_NV_stereo_view_rendering}; -static const libspirv::Extension pygen_variable_exts_SPV_NV_viewport_array2[] = {libspirv::Extension::kSPV_NV_viewport_array2}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_gpu_shader_half_float_fetch[] = {spvtools::Extension::kSPV_AMD_gpu_shader_half_float_fetch}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_explicit_vertex_parameter[] = {spvtools::Extension::kSPV_AMD_shader_explicit_vertex_parameter}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_fragment_mask[] = {spvtools::Extension::kSPV_AMD_shader_fragment_mask}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_image_load_store_lod[] = {spvtools::Extension::kSPV_AMD_shader_image_load_store_lod}; +static const spvtools::Extension pygen_variable_exts_SPV_AMD_texture_gather_bias_lod[] = {spvtools::Extension::kSPV_AMD_texture_gather_bias_lod}; +static const spvtools::Extension pygen_variable_exts_SPV_EXT_descriptor_indexing[] = {spvtools::Extension::kSPV_EXT_descriptor_indexing}; +static const spvtools::Extension pygen_variable_exts_SPV_EXT_fragment_fully_covered[] = {spvtools::Extension::kSPV_EXT_fragment_fully_covered}; +static const spvtools::Extension pygen_variable_exts_SPV_EXT_shader_stencil_export[] = {spvtools::Extension::kSPV_EXT_shader_stencil_export}; +static const spvtools::Extension pygen_variable_exts_SPV_EXT_shader_viewport_index_layer[] = {spvtools::Extension::kSPV_EXT_shader_viewport_index_layer}; +static const spvtools::Extension pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1[] = {spvtools::Extension::kSPV_GOOGLE_hlsl_functionality1}; +static const spvtools::Extension pygen_variable_exts_SPV_INTEL_subgroups[] = {spvtools::Extension::kSPV_INTEL_subgroups}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_16bit_storage[] = {spvtools::Extension::kSPV_KHR_16bit_storage}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_8bit_storage[] = {spvtools::Extension::kSPV_KHR_8bit_storage}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_device_group[] = {spvtools::Extension::kSPV_KHR_device_group}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_multiview[] = {spvtools::Extension::kSPV_KHR_multiview}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_post_depth_coverage[] = {spvtools::Extension::kSPV_KHR_post_depth_coverage}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_shader_atomic_counter_ops[] = {spvtools::Extension::kSPV_KHR_shader_atomic_counter_ops}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_shader_ballot[] = {spvtools::Extension::kSPV_KHR_shader_ballot}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_shader_draw_parameters[] = {spvtools::Extension::kSPV_KHR_shader_draw_parameters}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_storage_buffer_storage_classSPV_KHR_variable_pointers[] = {spvtools::Extension::kSPV_KHR_storage_buffer_storage_class, spvtools::Extension::kSPV_KHR_variable_pointers}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_subgroup_vote[] = {spvtools::Extension::kSPV_KHR_subgroup_vote}; +static const spvtools::Extension pygen_variable_exts_SPV_KHR_variable_pointers[] = {spvtools::Extension::kSPV_KHR_variable_pointers}; +static const spvtools::Extension pygen_variable_exts_SPV_NVX_multiview_per_view_attributes[] = {spvtools::Extension::kSPV_NVX_multiview_per_view_attributes}; +static const spvtools::Extension pygen_variable_exts_SPV_NV_geometry_shader_passthrough[] = {spvtools::Extension::kSPV_NV_geometry_shader_passthrough}; +static const spvtools::Extension pygen_variable_exts_SPV_NV_sample_mask_override_coverage[] = {spvtools::Extension::kSPV_NV_sample_mask_override_coverage}; +static const spvtools::Extension pygen_variable_exts_SPV_NV_shader_subgroup_partitioned[] = {spvtools::Extension::kSPV_NV_shader_subgroup_partitioned}; +static const spvtools::Extension pygen_variable_exts_SPV_NV_stereo_view_rendering[] = {spvtools::Extension::kSPV_NV_stereo_view_rendering}; +static const spvtools::Extension pygen_variable_exts_SPV_NV_viewport_array2[] = {spvtools::Extension::kSPV_NV_viewport_array2}; static const spv_operand_desc_t pygen_variable_ImageOperandsEntries[] = { {"None", 0x0000, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, @@ -90,7 +92,7 @@ static const spv_operand_desc_t pygen_variable_ImageOperandsEntries[] = { {"Grad", 0x0004, 0, nullptr, 0, nullptr, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"ConstOffset", 0x0008, 0, nullptr, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"Offset", 0x0010, 1, pygen_variable_caps_ImageGatherExtended, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)}, - {"ConstOffsets", 0x0020, 0, nullptr, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)}, + {"ConstOffsets", 0x0020, 1, pygen_variable_caps_ImageGatherExtended, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"Sample", 0x0040, 0, nullptr, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"MinLod", 0x0080, 1, pygen_variable_caps_MinLod, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1, 0)} }; @@ -429,10 +431,10 @@ static const spv_operand_desc_t pygen_variable_DecorationEntries[] = { {"AlignmentId", 46, 1, pygen_variable_caps_Kernel, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1,2)}, {"MaxByteOffsetId", 47, 1, pygen_variable_caps_Addresses, 0, nullptr, {SPV_OPERAND_TYPE_ID}, SPV_SPIRV_VERSION_WORD(1,2)}, {"ExplicitInterpAMD", 4999, 0, nullptr, 1, pygen_variable_exts_SPV_AMD_shader_explicit_vertex_parameter, {}, 0xffffffffu}, - {"OverrideCoverageNV", 5248, 1, pygen_variable_caps_SampleMaskOverrideCoverageNV, 0, nullptr, {}, 0xffffffffu}, - {"PassthroughNV", 5250, 1, pygen_variable_caps_GeometryShaderPassthroughNV, 0, nullptr, {}, 0xffffffffu}, + {"OverrideCoverageNV", 5248, 1, pygen_variable_caps_SampleMaskOverrideCoverageNV, 1, pygen_variable_exts_SPV_NV_sample_mask_override_coverage, {}, 0xffffffffu}, + {"PassthroughNV", 5250, 1, pygen_variable_caps_GeometryShaderPassthroughNV, 1, pygen_variable_exts_SPV_NV_geometry_shader_passthrough, {}, 0xffffffffu}, {"ViewportRelativeNV", 5252, 1, pygen_variable_caps_ShaderViewportMaskNV, 0, nullptr, {}, 0xffffffffu}, - {"SecondaryViewportRelativeNV", 5256, 1, pygen_variable_caps_ShaderStereoViewNV, 0, nullptr, {SPV_OPERAND_TYPE_LITERAL_INTEGER}, 0xffffffffu}, + {"SecondaryViewportRelativeNV", 5256, 1, pygen_variable_caps_ShaderStereoViewNV, 1, pygen_variable_exts_SPV_NV_stereo_view_rendering, {SPV_OPERAND_TYPE_LITERAL_INTEGER}, 0xffffffffu}, {"NonUniformEXT", 5300, 1, pygen_variable_caps_ShaderNonUniformEXT, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"HlslCounterBufferGOOGLE", 5634, 0, nullptr, 1, pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1, {SPV_OPERAND_TYPE_ID}, 0xffffffffu}, {"HlslSemanticGOOGLE", 5635, 0, nullptr, 1, pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1, {SPV_OPERAND_TYPE_LITERAL_STRING}, 0xffffffffu} @@ -504,8 +506,8 @@ static const spv_operand_desc_t pygen_variable_BuiltInEntries[] = { {"BaryCoordPullModelAMD", 4998, 0, nullptr, 1, pygen_variable_exts_SPV_AMD_shader_explicit_vertex_parameter, {}, 0xffffffffu}, {"FragStencilRefEXT", 5014, 1, pygen_variable_caps_StencilExportEXT, 1, pygen_variable_exts_SPV_EXT_shader_stencil_export, {}, 0xffffffffu}, {"ViewportMaskNV", 5253, 1, pygen_variable_caps_ShaderViewportMaskNV, 0, nullptr, {}, 0xffffffffu}, - {"SecondaryPositionNV", 5257, 1, pygen_variable_caps_ShaderStereoViewNV, 0, nullptr, {}, 0xffffffffu}, - {"SecondaryViewportMaskNV", 5258, 1, pygen_variable_caps_ShaderStereoViewNV, 0, nullptr, {}, 0xffffffffu}, + {"SecondaryPositionNV", 5257, 1, pygen_variable_caps_ShaderStereoViewNV, 1, pygen_variable_exts_SPV_NV_stereo_view_rendering, {}, 0xffffffffu}, + {"SecondaryViewportMaskNV", 5258, 1, pygen_variable_caps_ShaderStereoViewNV, 1, pygen_variable_exts_SPV_NV_stereo_view_rendering, {}, 0xffffffffu}, {"PositionPerViewNV", 5261, 1, pygen_variable_caps_PerViewAttributesNV, 0, nullptr, {}, 0xffffffffu}, {"ViewportMaskPerViewNV", 5262, 1, pygen_variable_caps_PerViewAttributesNV, 0, nullptr, {}, 0xffffffffu}, {"FullyCoveredEXT", 5264, 1, pygen_variable_caps_FragmentFullyCoveredEXT, 1, pygen_variable_exts_SPV_EXT_fragment_fully_covered, {}, 0xffffffffu} @@ -524,9 +526,9 @@ static const spv_operand_desc_t pygen_variable_GroupOperationEntries[] = { {"InclusiveScan", 1, 3, pygen_variable_caps_KernelGroupNonUniformArithmeticGroupNonUniformBallot, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"ExclusiveScan", 2, 3, pygen_variable_caps_KernelGroupNonUniformArithmeticGroupNonUniformBallot, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"ClusteredReduce", 3, 1, pygen_variable_caps_GroupNonUniformClustered, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1,3)}, - {"PartitionedReduceNV", 6, 1, pygen_variable_caps_GroupNonUniformPartitionedNV, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, - {"PartitionedInclusiveScanNV", 7, 1, pygen_variable_caps_GroupNonUniformPartitionedNV, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, - {"PartitionedExclusiveScanNV", 8, 1, pygen_variable_caps_GroupNonUniformPartitionedNV, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)} + {"PartitionedReduceNV", 6, 1, pygen_variable_caps_GroupNonUniformPartitionedNV, 1, pygen_variable_exts_SPV_NV_shader_subgroup_partitioned, {}, 0xffffffffu}, + {"PartitionedInclusiveScanNV", 7, 1, pygen_variable_caps_GroupNonUniformPartitionedNV, 1, pygen_variable_exts_SPV_NV_shader_subgroup_partitioned, {}, 0xffffffffu}, + {"PartitionedExclusiveScanNV", 8, 1, pygen_variable_caps_GroupNonUniformPartitionedNV, 1, pygen_variable_exts_SPV_NV_shader_subgroup_partitioned, {}, 0xffffffffu} }; static const spv_operand_desc_t pygen_variable_KernelEnqueueFlagsEntries[] = { @@ -573,7 +575,7 @@ static const spv_operand_desc_t pygen_variable_CapabilityEntries[] = { {"ImageRect", 36, 1, pygen_variable_caps_SampledRect, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"SampledRect", 37, 1, pygen_variable_caps_Shader, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"GenericPointer", 38, 1, pygen_variable_caps_Addresses, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, - {"Int8", 39, 1, pygen_variable_caps_Kernel, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, + {"Int8", 39, 0, nullptr, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"InputAttachment", 40, 1, pygen_variable_caps_Shader, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"SparseResidency", 41, 1, pygen_variable_caps_Shader, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, {"MinLod", 42, 1, pygen_variable_caps_Shader, 0, nullptr, {}, SPV_SPIRV_VERSION_WORD(1, 0)}, @@ -618,6 +620,9 @@ static const spv_operand_desc_t pygen_variable_CapabilityEntries[] = { {"VariablePointers", 4442, 1, pygen_variable_caps_VariablePointersStorageBuffer, 1, pygen_variable_exts_SPV_KHR_variable_pointers, {}, SPV_SPIRV_VERSION_WORD(1,3)}, {"AtomicStorageOps", 4445, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_shader_atomic_counter_ops, {}, 0xffffffffu}, {"SampleMaskPostDepthCoverage", 4447, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_post_depth_coverage, {}, 0xffffffffu}, + {"StorageBuffer8BitAccess", 4448, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_8bit_storage, {}, 0xffffffffu}, + {"UniformAndStorageBuffer8BitAccess", 4449, 1, pygen_variable_caps_StorageBuffer8BitAccess, 1, pygen_variable_exts_SPV_KHR_8bit_storage, {}, 0xffffffffu}, + {"StoragePushConstant8", 4450, 0, nullptr, 1, pygen_variable_exts_SPV_KHR_8bit_storage, {}, 0xffffffffu}, {"Float16ImageAMD", 5008, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_AMD_gpu_shader_half_float_fetch, {}, 0xffffffffu}, {"ImageGatherBiasLodAMD", 5009, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_AMD_texture_gather_bias_lod, {}, 0xffffffffu}, {"FragmentMaskAMD", 5010, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_AMD_shader_fragment_mask, {}, 0xffffffffu}, diff --git a/3rdparty/spirv-tools/include/spirv-tools/libspirv.h b/3rdparty/spirv-tools/include/spirv-tools/libspirv.h index d6cb60c03..a7e1b3007 100644 --- a/3rdparty/spirv-tools/include/spirv-tools/libspirv.h +++ b/3rdparty/spirv-tools/include/spirv-tools/libspirv.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_LIBSPIRV_H_ -#define SPIRV_TOOLS_LIBSPIRV_H_ +#ifndef INCLUDE_SPIRV_TOOLS_LIBSPIRV_H_ +#define INCLUDE_SPIRV_TOOLS_LIBSPIRV_H_ #ifdef __cplusplus extern "C" { @@ -412,6 +412,7 @@ typedef enum { SPV_ENV_OPENCL_EMBEDDED_2_2, // OpenCL Embedded Profile 2.2 latest revision. SPV_ENV_UNIVERSAL_1_3, // SPIR-V 1.3 latest revision, no other restrictions. SPV_ENV_VULKAN_1_1, // Vulkan 1.1 latest revision. + SPV_ENV_WEBGPU_0, // Work in progress WebGPU 1.0. } spv_target_env; // SPIR-V Validator can be parameterized with the following Universal Limits. @@ -472,6 +473,18 @@ SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxStoreStruct( SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxLogicalPointer( spv_validator_options options, bool val); +// Records whether or not the validator should relax the rules on block layout. +// +// When relaxed, it will enable VK_KHR_relaxed_block_layout when validating +// standard uniform/storage block layout. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxBlockLayout( + spv_validator_options options, bool val); + +// Records whether or not the validator should skip validating standard +// uniform/storage block layout. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetSkipBlockLayout( + spv_validator_options options, bool val); + // Encodes the given SPIR-V assembly text to its binary representation. The // length parameter specifies the number of bytes for text. Encoded binary will // be stored into *binary. Any error will be written into *diagnostic if @@ -583,4 +596,4 @@ SPIRV_TOOLS_EXPORT spv_result_t spvBinaryParse( } #endif -#endif // SPIRV_TOOLS_LIBSPIRV_H_ +#endif // INCLUDE_SPIRV_TOOLS_LIBSPIRV_H_ diff --git a/3rdparty/spirv-tools/include/spirv-tools/libspirv.hpp b/3rdparty/spirv-tools/include/spirv-tools/libspirv.hpp index 2e4aa628d..b6ae38c10 100644 --- a/3rdparty/spirv-tools/include/spirv-tools/libspirv.hpp +++ b/3rdparty/spirv-tools/include/spirv-tools/libspirv.hpp @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_LIBSPIRV_HPP_ -#define SPIRV_TOOLS_LIBSPIRV_HPP_ +#ifndef INCLUDE_SPIRV_TOOLS_LIBSPIRV_HPP_ +#define INCLUDE_SPIRV_TOOLS_LIBSPIRV_HPP_ #include #include @@ -81,6 +81,17 @@ class ValidatorOptions { spvValidatorOptionsSetRelaxStoreStruct(options_, val); } + // Enables VK_KHR_relaxed_block_layout when validating standard + // uniform/storage buffer layout. + void SetRelaxBlockLayout(bool val) { + spvValidatorOptionsSetRelaxBlockLayout(options_, val); + } + + // Skips validating standard uniform/storage buffer layout. + void SetSkipBlockLayout(bool val) { + spvValidatorOptionsSetSkipBlockLayout(options_, val); + } + // Records whether or not the validator should relax the rules on pointer // usage in logical addressing mode. // @@ -171,4 +182,4 @@ class SpirvTools { } // namespace spvtools -#endif // SPIRV_TOOLS_LIBSPIRV_HPP_ +#endif // INCLUDE_SPIRV_TOOLS_LIBSPIRV_HPP_ diff --git a/3rdparty/spirv-tools/include/spirv-tools/linker.hpp b/3rdparty/spirv-tools/include/spirv-tools/linker.hpp index cce78a445..d2f3e72ca 100644 --- a/3rdparty/spirv-tools/include/spirv-tools/linker.hpp +++ b/3rdparty/spirv-tools/include/spirv-tools/linker.hpp @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_LINKER_HPP_ -#define SPIRV_TOOLS_LINKER_HPP_ +#ifndef INCLUDE_SPIRV_TOOLS_LINKER_HPP_ +#define INCLUDE_SPIRV_TOOLS_LINKER_HPP_ #include @@ -94,4 +94,4 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries, } // namespace spvtools -#endif // SPIRV_TOOLS_LINKER_HPP_ +#endif // INCLUDE_SPIRV_TOOLS_LINKER_HPP_ diff --git a/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp b/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp index 3a03988e2..4364d9ff5 100644 --- a/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp +++ b/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_OPTIMIZER_HPP_ -#define SPIRV_TOOLS_OPTIMIZER_HPP_ +#ifndef INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ +#define INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ #include #include @@ -25,6 +25,10 @@ namespace spvtools { +namespace opt { +class Pass; +} + // C++ interface for SPIR-V optimization functionalities. It wraps the context // (including target environment and the corresponding SPIR-V grammar) and // provides methods for registering optimization passes and optimizing. @@ -41,6 +45,12 @@ class Optimizer { PassToken(std::unique_ptr); + // Tokens for built-in passes should be created using Create*Pass functions + // below; for out-of-tree passes, use this constructor instead. + // Note that this API isn't guaranteed to be stable and may change without + // preserving source or binary compatibility in the future. + PassToken(std::unique_ptr&& pass); + // Tokens can only be moved. Copying is disabled. PassToken(const PassToken&) = delete; PassToken(PassToken&&); @@ -73,6 +83,9 @@ class Optimizer { // invoked once for each message communicated from the library. void SetMessageConsumer(MessageConsumer consumer); + // Returns a reference to the registered message consumer. + const MessageConsumer& consumer() const; + // Registers the given |pass| to this optimizer. Passes will be run in the // exact order of registration. The token passed in will be consumed by this // method. @@ -90,25 +103,71 @@ class Optimizer { // Registers passes that attempt to legalize the generated code. // - // Note: this recipe is specially for legalizing SPIR-V. It should be used - // by compilers after translating HLSL source code literally. It should + // Note: this recipe is specially designed for legalizing SPIR-V. It should be + // used by compilers after translating HLSL source code literally. It should // *not* be used by general workloads for performance or size improvement. // // This sequence of passes is subject to constant review and will change // from time to time. Optimizer& RegisterLegalizationPasses(); + // Register passes specified in the list of |flags|. Each flag must be a + // string of a form accepted by Optimizer::FlagHasValidForm(). + // + // If the list of flags contains an invalid entry, it returns false and an + // error message is emitted to the MessageConsumer object (use + // Optimizer::SetMessageConsumer to define a message consumer, if needed). + // + // If all the passes are registered successfully, it returns true. + bool RegisterPassesFromFlags(const std::vector& flags); + + // Registers the optimization pass associated with |flag|. This only accepts + // |flag| values of the form "--pass_name[=pass_args]". If no such pass + // exists, it returns false. Otherwise, the pass is registered and it returns + // true. + // + // The following flags have special meaning: + // + // -O: Registers all performance optimization passes + // (Optimizer::RegisterPerformancePasses) + // + // -Os: Registers all size optimization passes + // (Optimizer::RegisterSizePasses). + // + // --legalize-hlsl: Registers all passes that legalize SPIR-V generated by an + // HLSL front-end. + bool RegisterPassFromFlag(const std::string& flag); + + // Validates that |flag| has a valid format. Strings accepted: + // + // --pass_name[=pass_args] + // -O + // -Os + // + // If |flag| takes one of the forms above, it returns true. Otherwise, it + // returns false. + bool FlagHasValidForm(const std::string& flag) const; + // Optimizes the given SPIR-V module |original_binary| and writes the // optimized binary into |optimized_binary|. // Returns true on successful optimization, whether or not the module is - // modified. Returns false if errors occur when processing |original_binary| - // using any of the registered passes. In that case, no further passes are - // executed and the contents in |optimized_binary| may be invalid. + // modified. Returns false if |original_binary| fails to validate or if errors + // occur when processing |original_binary| using any of the registered passes. + // In that case, no further passes are executed and the contents in + // |optimized_binary| may be invalid. // // It's allowed to alias |original_binary| to the start of |optimized_binary|. bool Run(const uint32_t* original_binary, size_t original_binary_size, std::vector* optimized_binary) const; + // Same as above, except passes |options| to the validator when trying to + // validate the binary. If |skip_validation| is true, then the caller is + // guaranteeing that |original_binary| is valid, and the validator will not + // be run. + bool Run(const uint32_t* original_binary, const size_t original_binary_size, + std::vector* optimized_binary, + const ValidatorOptions& options, bool skip_validation = false) const; + // Returns a vector of strings with all the pass names added to this // optimizer's pass manager. These strings are valid until the associated // pass manager is destroyed. @@ -483,6 +542,25 @@ Optimizer::PassToken CreateLocalRedundancyEliminationPass(); // the loops preheader. Optimizer::PassToken CreateLoopInvariantCodeMotionPass(); +// Creates a loop fission pass. +// This pass will split all top level loops whose register pressure exceedes the +// given |threshold|. +Optimizer::PassToken CreateLoopFissionPass(size_t threshold); + +// Creates a loop fusion pass. +// This pass will look for adjacent loops that are compatible and legal to be +// fused. The fuse all such loops as long as the register usage for the fused +// loop stays under the threshold defined by |max_registers_per_loop|. +Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop); + +// Creates a loop peeling pass. +// This pass will look for conditions inside a loop that are true or false only +// for the N first or last iteration. For loop with such condition, those N +// iterations of the loop will be executed outside of the main loop. +// To limit code size explosion, the loop peeling can only happen if the code +// size growth for each loop is under |code_growth_threshold|. +Optimizer::PassToken CreateLoopPeelingPass(); + // Creates a loop unswitch pass. // This pass will look for loop independent branch conditions and move the // condition out of the loop and version the loop based on the taken branch. @@ -496,8 +574,10 @@ Optimizer::PassToken CreateRedundancyEliminationPass(); // Create scalar replacement pass. // This pass replaces composite function scope variables with variables for each -// element if those elements are accessed individually. -Optimizer::PassToken CreateScalarReplacementPass(); +// element if those elements are accessed individually. The parameter is a +// limit on the number of members in the composite variable that the pass will +// consider replacing. +Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit = 100); // Create a private to local pass. // This pass looks for variables delcared in the private storage class that are @@ -552,6 +632,24 @@ Optimizer::PassToken CreateSSARewritePass(); // This pass looks to copy propagate memory references for arrays. It looks // for specific code patterns to recognize array copies. Optimizer::PassToken CreateCopyPropagateArraysPass(); + +// Create a vector dce pass. +// This pass looks for components of vectors that are unused, and removes them +// from the vector. Note this would still leave around lots of dead code that +// a pass of ADCE will be able to remove. +Optimizer::PassToken CreateVectorDCEPass(); + +// Create a pass to reduce the size of loads. +// This pass looks for loads of structures where only a few of its members are +// used. It replaces the loads feeding an OpExtract with an OpAccessChain and +// a load of the specific elements. +Optimizer::PassToken CreateReduceLoadSizePass(); + +// Create a pass to combine chained access chains. +// This pass looks for access chains fed by other access chains and combines +// them into a single instruction where possible. +Optimizer::PassToken CreateCombineAccessChainsPass(); + } // namespace spvtools -#endif // SPIRV_TOOLS_OPTIMIZER_HPP_ +#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/3rdparty/spirv-tools/kokoro/android/build.sh b/3rdparty/spirv-tools/kokoro/android/build.sh new file mode 100644 index 000000000..e31744fd1 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/android/build.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Android Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools +TARGET_ARCH="armeabi-v7a with NEON" +export ANDROID_NDK=/opt/android-ndk-r15c + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" +git clone --depth=1 https://github.com/taka-no-me/android-cmake.git android-cmake +export TOOLCHAIN_PATH=$PWD/android-cmake/android.toolchain.cmake + + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... +cmake -DCMAKE_BUILD_TYPE=Release -DANDROID_NATIVE_API_LEVEL=android-14 -DANDROID_ABI="armeabi-v7a with NEON" -DSPIRV_BUILD_COMPRESSION=ON -DSPIRV_SKIP_TESTS=ON -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN_PATH -GNinja -DANDROID_NDK=$ANDROID_NDK .. + +echo $(date): Build everything... +ninja +echo $(date): Build completed. diff --git a/3rdparty/spirv-tools/kokoro/android/continuous.cfg b/3rdparty/spirv-tools/kokoro/android/continuous.cfg new file mode 100644 index 000000000..3bdb17a57 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/android/continuous.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +# +build_file: "SPIRV-Tools/kokoro/android/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/android/presubmit.cfg b/3rdparty/spirv-tools/kokoro/android/presubmit.cfg new file mode 100644 index 000000000..21589ccc1 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/android/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/android/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/check-format/build.sh b/3rdparty/spirv-tools/kokoro/check-format/build.sh new file mode 100644 index 000000000..2a8d50fb5 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/check-format/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Android Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools + +# Get clang-format-5.0.0. +# Once kokoro upgrades the Ubuntu VMs, we can use 'apt-get install clang-format' +curl -L http://releases.llvm.org/5.0.0/clang+llvm-5.0.0-linux-x86_64-ubuntu14.04.tar.xz -o clang-llvm.tar.xz +tar xf clang-llvm.tar.xz +export PATH=$PWD/clang+llvm-5.0.0-linux-x86_64-ubuntu14.04/bin:$PATH + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 +curl -L http://llvm.org/svn/llvm-project/cfe/trunk/tools/clang-format/clang-format-diff.py -o utils/clang-format-diff.py; + +echo $(date): Check formatting... +./utils/check_code_format.sh; +echo $(date): check completed. diff --git a/3rdparty/spirv-tools/kokoro/check-format/presubmit_check_format.cfg b/3rdparty/spirv-tools/kokoro/check-format/presubmit_check_format.cfg new file mode 100644 index 000000000..1993289d6 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/check-format/presubmit_check_format.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/check-format/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/img/linux.png b/3rdparty/spirv-tools/kokoro/img/linux.png new file mode 100644 index 000000000..ff066d979 Binary files /dev/null and b/3rdparty/spirv-tools/kokoro/img/linux.png differ diff --git a/3rdparty/spirv-tools/kokoro/img/macos.png b/3rdparty/spirv-tools/kokoro/img/macos.png new file mode 100644 index 000000000..d1349c092 Binary files /dev/null and b/3rdparty/spirv-tools/kokoro/img/macos.png differ diff --git a/3rdparty/spirv-tools/kokoro/img/windows.png b/3rdparty/spirv-tools/kokoro/img/windows.png new file mode 100644 index 000000000..a37846999 Binary files /dev/null and b/3rdparty/spirv-tools/kokoro/img/windows.png differ diff --git a/3rdparty/spirv-tools/kokoro/linux-clang-debug/build.sh b/3rdparty/spirv-tools/kokoro/linux-clang-debug/build.sh new file mode 100644 index 000000000..11b2968a6 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-clang-debug/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh DEBUG clang diff --git a/3rdparty/spirv-tools/kokoro/linux-clang-debug/continuous.cfg b/3rdparty/spirv-tools/kokoro/linux-clang-debug/continuous.cfg new file mode 100644 index 000000000..e92f059ed --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-clang-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/linux-clang-debug/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/linux-clang-debug/presubmit.cfg b/3rdparty/spirv-tools/kokoro/linux-clang-debug/presubmit.cfg new file mode 100644 index 000000000..5011b445e --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-clang-debug/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/linux-clang-debug/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/linux-clang-release/build.sh b/3rdparty/spirv-tools/kokoro/linux-clang-release/build.sh new file mode 100644 index 000000000..476433171 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-clang-release/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh RELEASE clang diff --git a/3rdparty/spirv-tools/kokoro/linux-clang-release/continuous.cfg b/3rdparty/spirv-tools/kokoro/linux-clang-release/continuous.cfg new file mode 100644 index 000000000..687434acc --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-clang-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/linux-clang-release/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/linux-clang-release/presubmit.cfg b/3rdparty/spirv-tools/kokoro/linux-clang-release/presubmit.cfg new file mode 100644 index 000000000..b7b9b5594 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-clang-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/linux-clang-release/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/linux-gcc-debug/build.sh b/3rdparty/spirv-tools/kokoro/linux-gcc-debug/build.sh new file mode 100644 index 000000000..3ef1e251b --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-gcc-debug/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh DEBUG gcc diff --git a/3rdparty/spirv-tools/kokoro/linux-gcc-debug/continuous.cfg b/3rdparty/spirv-tools/kokoro/linux-gcc-debug/continuous.cfg new file mode 100644 index 000000000..4f8418d84 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-gcc-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/linux-gcc-debug/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/linux-gcc-debug/presubmit.cfg b/3rdparty/spirv-tools/kokoro/linux-gcc-debug/presubmit.cfg new file mode 100644 index 000000000..2d9fe5c99 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-gcc-debug/presubmit.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/linux-gcc-debug/build.sh" + diff --git a/3rdparty/spirv-tools/kokoro/linux-gcc-release/build.sh b/3rdparty/spirv-tools/kokoro/linux-gcc-release/build.sh new file mode 100644 index 000000000..3e97d8d3b --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-gcc-release/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh RELEASE gcc diff --git a/3rdparty/spirv-tools/kokoro/linux-gcc-release/continuous.cfg b/3rdparty/spirv-tools/kokoro/linux-gcc-release/continuous.cfg new file mode 100644 index 000000000..41a0024e7 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-gcc-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/linux-gcc-release/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/linux-gcc-release/presubmit.cfg b/3rdparty/spirv-tools/kokoro/linux-gcc-release/presubmit.cfg new file mode 100644 index 000000000..c249a5ab0 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/linux-gcc-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/linux-gcc-release/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/macos-clang-debug/build.sh b/3rdparty/spirv-tools/kokoro/macos-clang-debug/build.sh new file mode 100644 index 000000000..8d9a062f6 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/macos-clang-debug/build.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MacOS Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/macos/build.sh Debug + diff --git a/3rdparty/spirv-tools/kokoro/macos-clang-debug/continuous.cfg b/3rdparty/spirv-tools/kokoro/macos-clang-debug/continuous.cfg new file mode 100644 index 000000000..84aaa5c25 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/macos-clang-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/macos-clang-debug/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/macos-clang-debug/presubmit.cfg b/3rdparty/spirv-tools/kokoro/macos-clang-debug/presubmit.cfg new file mode 100644 index 000000000..1d2f60da9 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/macos-clang-debug/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/macos-clang-debug/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/macos-clang-release/build.sh b/3rdparty/spirv-tools/kokoro/macos-clang-release/build.sh new file mode 100644 index 000000000..ccc8b16aa --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/macos-clang-release/build.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MacOS Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/macos/build.sh RelWithDebInfo + diff --git a/3rdparty/spirv-tools/kokoro/macos-clang-release/continuous.cfg b/3rdparty/spirv-tools/kokoro/macos-clang-release/continuous.cfg new file mode 100644 index 000000000..a8e23a71a --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/macos-clang-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/macos-clang-release/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/macos-clang-release/presubmit.cfg b/3rdparty/spirv-tools/kokoro/macos-clang-release/presubmit.cfg new file mode 100644 index 000000000..dbaa266cc --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/macos-clang-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/macos-clang-release/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/ndk-build/build.sh b/3rdparty/spirv-tools/kokoro/ndk-build/build.sh new file mode 100644 index 000000000..d51f071ea --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/ndk-build/build.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" + +# NDK Path +export ANDROID_NDK=/opt/android-ndk-r15c + +# Get the dependencies. +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build +mkdir libs +mkdir app + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting ndk-build ... +$ANDROID_NDK/ndk-build \ + -C $SRC/android_test \ + NDK_PROJECT_PATH=. \ + NDK_LIBS_OUT=./libs \ + NDK_APP_OUT=./app \ + -j8 + +echo $(date): ndk-build completed. + diff --git a/3rdparty/spirv-tools/kokoro/ndk-build/continuous.cfg b/3rdparty/spirv-tools/kokoro/ndk-build/continuous.cfg new file mode 100644 index 000000000..b908a4814 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/ndk-build/continuous.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +# +build_file: "SPIRV-Tools/kokoro/ndk-build/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/ndk-build/presubmit.cfg b/3rdparty/spirv-tools/kokoro/ndk-build/presubmit.cfg new file mode 100644 index 000000000..3c1be4bf7 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/ndk-build/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/ndk-build/build.sh" diff --git a/3rdparty/spirv-tools/kokoro/scripts/linux/build.sh b/3rdparty/spirv-tools/kokoro/scripts/linux/build.sh new file mode 100644 index 000000000..d457539d4 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/scripts/linux/build.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools +CONFIG=$1 +COMPILER=$2 + +SKIP_TESTS="False" +BUILD_TYPE="Debug" + +CMAKE_C_CXX_COMPILER="" +if [ $COMPILER = "clang" ] +then + sudo ln -s /usr/bin/clang-3.8 /usr/bin/clang + sudo ln -s /usr/bin/clang++-3.8 /usr/bin/clang++ + CMAKE_C_CXX_COMPILER="-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++" +fi + +# Possible configurations are: +# ASAN, COVERAGE, RELEASE, DEBUG, DEBUG_EXCEPTION, RELEASE_MINGW + +if [ $CONFIG = "RELEASE" ] || [ $CONFIG = "RELEASE_MINGW" ] +then + BUILD_TYPE="RelWithDebInfo" +fi + +ADDITIONAL_CMAKE_FLAGS="" +if [ $CONFIG = "ASAN" ] +then + ADDITIONAL_CMAKE_FLAGS="-DCMAKE_CXX_FLAGS=-fsanitize=address -DCMAKE_C_FLAGS=-fsanitize=address" + export ASAN_SYMBOLIZER_PATH=/usr/bin/llvm-symbolizer-3.4 +elif [ $CONFIG = "COVERAGE" ] +then + ADDITIONAL_CMAKE_FLAGS="-DENABLE_CODE_COVERAGE=ON" + SKIP_TESTS="True" +elif [ $CONFIG = "DEBUG_EXCEPTION" ] +then + ADDITIONAL_CMAKE_FLAGS="-DDISABLE_EXCEPTIONS=ON -DDISABLE_RTTI=ON" +elif [ $CONFIG = "RELEASE_MINGW" ] +then + ADDITIONAL_CMAKE_FLAGS="-Dgtest_disable_pthreads=ON -DCMAKE_TOOLCHAIN_FILE=$SRC/cmake/linux-mingw-toolchain.cmake" + SKIP_TESTS="True" +fi + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... +cmake -GNinja -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF $ADDITIONAL_CMAKE_FLAGS $CMAKE_C_CXX_COMPILER .. + +echo $(date): Build everything... +ninja +echo $(date): Build completed. + +if [ $CONFIG = "COVERAGE" ] +then + echo $(date): Check coverage... + ninja report-coverage + echo $(date): Check coverage completed. +fi + +echo $(date): Starting ctest... +if [ $SKIP_TESTS = "False" ] +then + ctest -j4 --output-on-failure --timeout 300 +fi +echo $(date): ctest completed. + diff --git a/3rdparty/spirv-tools/kokoro/scripts/macos/build.sh b/3rdparty/spirv-tools/kokoro/scripts/macos/build.sh new file mode 100644 index 000000000..a7f0453fe --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/scripts/macos/build.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MacOS Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools +BUILD_TYPE=$1 + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-mac.zip +unzip -q ninja-mac.zip +chmod +x ninja +export PATH="$PWD:$PATH" + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... +cmake -GNinja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=$BUILD_TYPE .. + +echo $(date): Build everything... +ninja +echo $(date): Build completed. + +echo $(date): Starting ctest... +ctest -j4 --output-on-failure --timeout 300 +echo $(date): ctest completed. + diff --git a/3rdparty/spirv-tools/kokoro/scripts/windows/build.bat b/3rdparty/spirv-tools/kokoro/scripts/windows/build.bat new file mode 100644 index 000000000..a2472fb4f --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/scripts/windows/build.bat @@ -0,0 +1,90 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +set BUILD_ROOT=%cd% +set SRC=%cd%\github\SPIRV-Tools +set BUILD_TYPE=%1 +set VS_VERSION=%2 + +:: Force usage of python 2.7 rather than 3.6 +set PATH=C:\python27;%PATH% + +cd %SRC% +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +:: ######################################### +:: set up msvc build env +:: ######################################### +if %VS_VERSION% == 2017 ( + call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 + echo "Using VS 2017..." +) else if %VS_VERSION% == 2015 ( + call "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" x64 + echo "Using VS 2015..." +) else if %VS_VERSION% == 2013 ( + call "C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\vcvarsall.bat" x64 + echo "Using VS 2013..." +) + +cd %SRC% +mkdir build +cd build + +:: ######################################### +:: Start building. +:: ######################################### +echo "Starting build... %DATE% %TIME%" +if "%KOKORO_GITHUB_COMMIT%." == "." ( + set BUILD_SHA=%KOKORO_GITHUB_PULL_REQUEST_COMMIT% +) else ( + set BUILD_SHA=%KOKORO_GITHUB_COMMIT% +) + +:: Skip building tests for VS2013 +if %VS_VERSION% == 2013 ( + cmake -GNinja -DSPIRV_SKIP_TESTS=ON -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe .. +) else ( + cmake -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe .. +) + +if %ERRORLEVEL% GEQ 1 exit /b %ERRORLEVEL% + +echo "Build everything... %DATE% %TIME%" +ninja +if %ERRORLEVEL% GEQ 1 exit /b %ERRORLEVEL% +echo "Build Completed %DATE% %TIME%" + +:: ################################################ +:: Run the tests (We no longer run tests on VS2013) +:: ################################################ +if NOT %VS_VERSION% == 2013 ( + echo "Running Tests... %DATE% %TIME%" + ctest -C %BUILD_TYPE% --output-on-failure --timeout 300 + if %ERRORLEVEL% GEQ 1 exit /b %ERRORLEVEL% + echo "Tests Completed %DATE% %TIME%" +) + +:: Clean up some directories. +rm -rf %SRC%\build +rm -rf %SRC%\external + +exit /b %ERRORLEVEL% + diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/build.bat b/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/build.bat new file mode 100644 index 000000000..e77172afc --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/build.bat @@ -0,0 +1,24 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. +set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat RelWithDebInfo 2013 + diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/continuous.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/continuous.cfg new file mode 100644 index 000000000..5dfcba63b --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2013-release/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/presubmit.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/presubmit.cfg new file mode 100644 index 000000000..7d3b23822 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2013-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2013-release/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/build.bat b/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/build.bat new file mode 100644 index 000000000..c0e4bd317 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/build.bat @@ -0,0 +1,24 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. +set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat RelWithDebInfo 2015 + diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/continuous.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/continuous.cfg new file mode 100644 index 000000000..3e47e5268 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2015-release/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/presubmit.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/presubmit.cfg new file mode 100644 index 000000000..85a162593 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2015-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2015-release/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/build.bat b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/build.bat new file mode 100644 index 000000000..25783a9e5 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/build.bat @@ -0,0 +1,23 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. +set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat Debug 2017 diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/continuous.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/continuous.cfg new file mode 100644 index 000000000..b842c30f1 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-debug/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/presubmit.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/presubmit.cfg new file mode 100644 index 000000000..a7a553aee --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-debug/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-debug/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/build.bat b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/build.bat new file mode 100644 index 000000000..899fcbcfb --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/build.bat @@ -0,0 +1,24 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. +set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat RelWithDebInfo 2017 + diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/continuous.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/continuous.cfg new file mode 100644 index 000000000..7b8c2ff2b --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-release/build.bat" diff --git a/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/presubmit.cfg b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/presubmit.cfg new file mode 100644 index 000000000..5efd42927 --- /dev/null +++ b/3rdparty/spirv-tools/kokoro/windows-msvc-2017-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-release/build.bat" diff --git a/3rdparty/spirv-tools/source/CMakeLists.txt b/3rdparty/spirv-tools/source/CMakeLists.txt index e9444750a..4df5de3ad 100644 --- a/3rdparty/spirv-tools/source/CMakeLists.txt +++ b/3rdparty/spirv-tools/source/CMakeLists.txt @@ -217,9 +217,11 @@ set(SPIRV_SOURCES ${spirv-tools_SOURCE_DIR}/include/spirv-tools/libspirv.h ${CMAKE_CURRENT_SOURCE_DIR}/util/bitutils.h - ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_stream.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_vector.h ${CMAKE_CURRENT_SOURCE_DIR}/util/hex_float.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/make_unique.h ${CMAKE_CURRENT_SOURCE_DIR}/util/parse_number.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/small_vector.h ${CMAKE_CURRENT_SOURCE_DIR}/util/string_utils.h ${CMAKE_CURRENT_SOURCE_DIR}/util/timer.h ${CMAKE_CURRENT_SOURCE_DIR}/assembly_grammar.h @@ -250,9 +252,9 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/table.h ${CMAKE_CURRENT_SOURCE_DIR}/text.h ${CMAKE_CURRENT_SOURCE_DIR}/text_handler.h - ${CMAKE_CURRENT_SOURCE_DIR}/validate.h + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate.h - ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_vector.cpp ${CMAKE_CURRENT_SOURCE_DIR}/util/parse_number.cpp ${CMAKE_CURRENT_SOURCE_DIR}/util/string_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/assembly_grammar.cpp @@ -264,7 +266,6 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/extensions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libspirv.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/message.cpp ${CMAKE_CURRENT_SOURCE_DIR}/name_mapper.cpp ${CMAKE_CURRENT_SOURCE_DIR}/opcode.cpp ${CMAKE_CURRENT_SOURCE_DIR}/operand.cpp @@ -272,35 +273,43 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/print.cpp ${CMAKE_CURRENT_SOURCE_DIR}/software_version.cpp ${CMAKE_CURRENT_SOURCE_DIR}/spirv_endian.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/spirv_stats.cpp ${CMAKE_CURRENT_SOURCE_DIR}/spirv_target_env.cpp ${CMAKE_CURRENT_SOURCE_DIR}/spirv_validator_options.cpp ${CMAKE_CURRENT_SOURCE_DIR}/table.cpp ${CMAKE_CURRENT_SOURCE_DIR}/text.cpp ${CMAKE_CURRENT_SOURCE_DIR}/text_handler.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_adjacency.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_arithmetics.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_atomics.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_barriers.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_bitwise.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_builtins.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_capability.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_cfg.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_composites.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_conversion.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_datarules.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_decorations.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_derivatives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_ext_inst.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_id.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_image.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_instruction.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_layout.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_literals.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_logicals.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/validate_type_unique.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_adjacency.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_annotation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_arithmetics.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_atomics.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_barriers.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_bitwise.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_builtins.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_capability.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_cfg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_composites.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_constants.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_conversion.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_datarules.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_debug.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ext_inst.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_execution_limitations.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_function.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_id.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_image.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_interfaces.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_instruction.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_layout.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_literals.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_logicals.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h ${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/construct.cpp diff --git a/3rdparty/spirv-tools/source/assembly_grammar.cpp b/3rdparty/spirv-tools/source/assembly_grammar.cpp index 2fde64053..4d98e3dab 100644 --- a/3rdparty/spirv-tools/source/assembly_grammar.cpp +++ b/3rdparty/spirv-tools/source/assembly_grammar.cpp @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_grammar.h" +#include "source/assembly_grammar.h" #include #include #include -#include "ext_inst.h" -#include "opcode.h" -#include "operand.h" -#include "table.h" +#include "source/ext_inst.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/table.h" +namespace spvtools { namespace { /// @brief Parses a mask expression string for the given operand type. @@ -164,9 +165,7 @@ static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes const size_t kNumOpSpecConstantOpcodes = sizeof(kOpSpecConstantOpcodes) / sizeof(kOpSpecConstantOpcodes[0]); -} // anonymous namespace - -namespace libspirv { +} // namespace bool AssemblyGrammar::isValid() const { return operandTable_ && opcodeTable_ && extInstTable_; @@ -260,4 +259,5 @@ void AssemblyGrammar::pushOperandTypesForMask( spv_operand_pattern_t* pattern) const { spvPushOperandTypesForMask(target_env_, operandTable_, type, mask, pattern); } -} // namespace libspirv + +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/assembly_grammar.h b/3rdparty/spirv-tools/source/assembly_grammar.h index 6837a0b6e..17c2bd3ba 100644 --- a/3rdparty/spirv-tools/source/assembly_grammar.h +++ b/3rdparty/spirv-tools/source/assembly_grammar.h @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_ASSEMBLY_GRAMMAR_H_ -#define LIBSPIRV_ASSEMBLY_GRAMMAR_H_ +#ifndef SOURCE_ASSEMBLY_GRAMMAR_H_ +#define SOURCE_ASSEMBLY_GRAMMAR_H_ -#include "enum_set.h" -#include "latest_version_spirv_header.h" -#include "operand.h" +#include "source/enum_set.h" +#include "source/latest_version_spirv_header.h" +#include "source/operand.h" +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" -namespace libspirv { +namespace spvtools { // Encapsulates the grammar to use for SPIR-V assembly. // Contains methods to query for valid instructions and operands. @@ -132,6 +132,7 @@ class AssemblyGrammar { const spv_opcode_table opcodeTable_; const spv_ext_inst_table extInstTable_; }; -} // namespace libspirv -#endif // LIBSPIRV_ASSEMBLY_GRAMMAR_H_ +} // namespace spvtools + +#endif // SOURCE_ASSEMBLY_GRAMMAR_H_ diff --git a/3rdparty/spirv-tools/source/binary.cpp b/3rdparty/spirv-tools/source/binary.cpp index 7ac5765ee..6604d8094 100644 --- a/3rdparty/spirv-tools/source/binary.cpp +++ b/3rdparty/spirv-tools/source/binary.cpp @@ -12,24 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "binary.h" +#include "source/binary.h" #include #include #include #include #include +#include #include #include -#include "assembly_grammar.h" -#include "diagnostic.h" -#include "ext_inst.h" -#include "latest_version_spirv_header.h" -#include "opcode.h" -#include "operand.h" -#include "spirv_constant.h" -#include "spirv_endian.h" +#include "source/assembly_grammar.h" +#include "source/diagnostic.h" +#include "source/ext_inst.h" +#include "source/latest_version_spirv_header.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, const spv_endianness_t endian, @@ -121,12 +122,13 @@ class Parser { // the input stream, and for the given error code. Any data written to the // returned object will be propagated to the current parse's diagnostic // object. - libspirv::DiagnosticStream diagnostic(spv_result_t error) { - return libspirv::DiagnosticStream({0, 0, _.word_index}, consumer_, error); + spvtools::DiagnosticStream diagnostic(spv_result_t error) { + return spvtools::DiagnosticStream({0, 0, _.word_index}, consumer_, "", + error); } // Returns a diagnostic stream object with the default parse error code. - libspirv::DiagnosticStream diagnostic() { + spvtools::DiagnosticStream diagnostic() { // The default failure for parsing is invalid binary. return diagnostic(SPV_ERROR_INVALID_BINARY); } @@ -156,7 +158,7 @@ class Parser { // Data members - const libspirv::AssemblyGrammar grammar_; // SPIR-V syntax utility. + const spvtools::AssemblyGrammar grammar_; // SPIR-V syntax utility. const spvtools::MessageConsumer& consumer_; // Message consumer callback. void* const user_data_; // Context for the callbacks const spv_parsed_header_fn_t parsed_header_fn_; // Parsed header callback @@ -766,7 +768,7 @@ spv_result_t spvBinaryParse(const spv_const_context context, void* user_data, spv_context_t hijack_context = *context; if (diagnostic) { *diagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); } Parser parser(&hijack_context, user_data, parsed_header, parsed_instruction); return parser.parse(code, num_words, diagnostic); diff --git a/3rdparty/spirv-tools/source/binary.h b/3rdparty/spirv-tools/source/binary.h index f6237e32b..66d24c7e4 100644 --- a/3rdparty/spirv-tools/source/binary.h +++ b/3rdparty/spirv-tools/source/binary.h @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_BINARY_H_ -#define LIBSPIRV_BINARY_H_ +#ifndef SOURCE_BINARY_H_ +#define SOURCE_BINARY_H_ +#include "source/spirv_definition.h" #include "spirv-tools/libspirv.h" -#include "spirv_definition.h" // Functions @@ -33,4 +33,4 @@ spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, // replacement for C11's strnlen_s which might not exist in all environments. size_t spv_strnlen_s(const char* str, size_t strsz); -#endif // LIBSPIRV_BINARY_H_ +#endif // SOURCE_BINARY_H_ diff --git a/3rdparty/spirv-tools/source/cfa.h b/3rdparty/spirv-tools/source/cfa.h index 1022e3f2d..97ef398d6 100644 --- a/3rdparty/spirv-tools/source/cfa.h +++ b/3rdparty/spirv-tools/source/cfa.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPVTOOLS_CFA_H_ -#define SPVTOOLS_CFA_H_ +#ifndef SOURCE_CFA_H_ +#define SOURCE_CFA_H_ #include #include @@ -25,14 +25,6 @@ #include #include -using std::find; -using std::function; -using std::get; -using std::pair; -using std::unordered_map; -using std::unordered_set; -using std::vector; - namespace spvtools { // Control Flow Analysis of control flow graphs of basic block nodes |BB|. @@ -111,8 +103,8 @@ class CFA { /// block /// without predecessors (such as the root node) is its own immediate /// dominator. - static vector> CalculateDominators( - const vector& postorder, get_blocks_func predecessor_func); + static std::vector> CalculateDominators( + const std::vector& postorder, get_blocks_func predecessor_func); // Computes a minimal set of root nodes required to traverse, in the forward // direction, the CFG represented by the given vector of blocks, and successor @@ -133,7 +125,8 @@ class CFA { }; template -bool CFA::FindInWorkList(const vector& work_list, uint32_t id) { +bool CFA::FindInWorkList(const std::vector& work_list, + uint32_t id) { for (const auto b : work_list) { if (b.block->id() == id) return true; } @@ -141,19 +134,19 @@ bool CFA::FindInWorkList(const vector& work_list, uint32_t id) { } template -void CFA::DepthFirstTraversal(const BB* entry, - get_blocks_func successor_func, - function preorder, - function postorder, - function backedge) { - unordered_set processed; +void CFA::DepthFirstTraversal( + const BB* entry, get_blocks_func successor_func, + std::function preorder, + std::function postorder, + std::function backedge) { + std::unordered_set processed; /// NOTE: work_list is the sequence of nodes from the root node to the node /// being processed in the traversal - vector work_list; + std::vector work_list; work_list.reserve(10); - work_list.push_back({entry, begin(*successor_func(entry))}); + work_list.push_back({entry, std::begin(*successor_func(entry))}); preorder(entry); processed.insert(entry->id()); @@ -171,7 +164,7 @@ void CFA::DepthFirstTraversal(const BB* entry, if (processed.count(child->id()) == 0) { preorder(child); work_list.emplace_back( - block_info{child, begin(*successor_func(child))}); + block_info{child, std::begin(*successor_func(child))}); processed.insert(child->id()); } } @@ -179,15 +172,15 @@ void CFA::DepthFirstTraversal(const BB* entry, } template -vector> CFA::CalculateDominators( - const vector& postorder, get_blocks_func predecessor_func) { +std::vector> CFA::CalculateDominators( + const std::vector& postorder, get_blocks_func predecessor_func) { struct block_detail { size_t dominator; ///< The index of blocks's dominator in post order array size_t postorder_index; ///< The index of the block in the post order array }; const size_t undefined_dom = postorder.size(); - unordered_map idoms; + std::unordered_map idoms; for (size_t i = 0; i < postorder.size(); i++) { idoms[postorder[i]] = {undefined_dom, i}; } @@ -197,14 +190,14 @@ vector> CFA::CalculateDominators( while (changed) { changed = false; for (auto b = postorder.rbegin() + 1; b != postorder.rend(); ++b) { - const vector& predecessors = *predecessor_func(*b); + const std::vector& predecessors = *predecessor_func(*b); // Find the first processed/reachable predecessor that is reachable // in the forward traversal. - auto res = find_if(begin(predecessors), end(predecessors), - [&idoms, undefined_dom](BB* pred) { - return idoms.count(pred) && - idoms[pred].dominator != undefined_dom; - }); + auto res = std::find_if(std::begin(predecessors), std::end(predecessors), + [&idoms, undefined_dom](BB* pred) { + return idoms.count(pred) && + idoms[pred].dominator != undefined_dom; + }); if (res == end(predecessors)) continue; const BB* idom = *res; size_t idom_idx = idoms[idom].postorder_index; @@ -237,13 +230,29 @@ vector> CFA::CalculateDominators( } } - vector> out; + std::vector> out; for (auto idom : idoms) { // NOTE: performing a const cast for convenient usage with // UpdateImmediateDominators - out.push_back({const_cast(get<0>(idom)), - const_cast(postorder[get<1>(idom).dominator])}); + out.push_back({const_cast(std::get<0>(idom)), + const_cast(postorder[std::get<1>(idom).dominator])}); } + + // Sort by postorder index to generate a deterministic ordering of edges. + std::sort( + out.begin(), out.end(), + [&idoms](const std::pair& lhs, + const std::pair& rhs) { + assert(lhs.first); + assert(lhs.second); + assert(rhs.first); + assert(rhs.second); + auto lhs_indices = std::make_pair(idoms[lhs.first].postorder_index, + idoms[lhs.second].postorder_index); + auto rhs_indices = std::make_pair(idoms[rhs.first].postorder_index, + idoms[rhs.second].postorder_index); + return lhs_indices < rhs_indices; + }); return out; } @@ -335,4 +344,4 @@ void CFA::ComputeAugmentedCFG( } // namespace spvtools -#endif // SPVTOOLS_CFA_H_ +#endif // SOURCE_CFA_H_ diff --git a/3rdparty/spirv-tools/source/comp/CMakeLists.txt b/3rdparty/spirv-tools/source/comp/CMakeLists.txt index ff52d5e1e..f65f9f670 100644 --- a/3rdparty/spirv-tools/source/comp/CMakeLists.txt +++ b/3rdparty/spirv-tools/source/comp/CMakeLists.txt @@ -13,7 +13,21 @@ # limitations under the License. if(SPIRV_BUILD_COMPRESSION) - add_library(SPIRV-Tools-comp markv_codec.cpp) + add_library(SPIRV-Tools-comp + bit_stream.cpp + bit_stream.h + huffman_codec.h + markv_codec.cpp + markv_codec.h + markv.cpp + markv.h + markv_decoder.cpp + markv_decoder.h + markv_encoder.cpp + markv_encoder.h + markv_logger.h + move_to_front.h + move_to_front.cpp) spvtools_default_compile_options(SPIRV-Tools-comp) target_include_directories(SPIRV-Tools-comp diff --git a/3rdparty/spirv-tools/source/util/bit_stream.cpp b/3rdparty/spirv-tools/source/comp/bit_stream.cpp similarity index 81% rename from 3rdparty/spirv-tools/source/util/bit_stream.cpp rename to 3rdparty/spirv-tools/source/comp/bit_stream.cpp index 77e2bc17d..a5769e03e 100644 --- a/3rdparty/spirv-tools/source/util/bit_stream.cpp +++ b/3rdparty/spirv-tools/source/comp/bit_stream.cpp @@ -18,10 +18,10 @@ #include #include -#include "util/bit_stream.h" - -namespace spvutils { +#include "source/comp/bit_stream.h" +namespace spvtools { +namespace comp { namespace { // Returns if the system is little-endian. Unfortunately only works during @@ -197,41 +197,6 @@ bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val, } // namespace -size_t Log2U64(uint64_t val) { - size_t res = 0; - - if (val & 0xFFFFFFFF00000000) { - val >>= 32; - res |= 32; - } - - if (val & 0xFFFF0000) { - val >>= 16; - res |= 16; - } - - if (val & 0xFF00) { - val >>= 8; - res |= 8; - } - - if (val & 0xF0) { - val >>= 4; - res |= 4; - } - - if (val & 0xC) { - val >>= 2; - res |= 2; - } - - if (val & 0x2) { - res |= 1; - } - - return res; -} - void BitWriterInterface::WriteVariableWidthU64(uint64_t val, size_t chunk_length) { WriteVariableWidthUnsigned(this, val, chunk_length); @@ -247,41 +212,11 @@ void BitWriterInterface::WriteVariableWidthU16(uint16_t val, WriteVariableWidthUnsigned(this, val, chunk_length); } -void BitWriterInterface::WriteVariableWidthU8(uint8_t val, - size_t chunk_length) { - WriteVariableWidthUnsigned(this, val, chunk_length); -} - void BitWriterInterface::WriteVariableWidthS64(int64_t val, size_t chunk_length, size_t zigzag_exponent) { WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent); } -void BitWriterInterface::WriteVariableWidthS32(int32_t val, size_t chunk_length, - size_t zigzag_exponent) { - WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent); -} - -void BitWriterInterface::WriteVariableWidthS16(int16_t val, size_t chunk_length, - size_t zigzag_exponent) { - WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent); -} - -void BitWriterInterface::WriteVariableWidthS8(int8_t val, size_t chunk_length, - size_t zigzag_exponent) { - WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent); -} - -void BitWriterInterface::WriteFixedWidth(uint64_t val, uint64_t max_val) { - if (val > max_val) { - assert(0 && "WriteFixedWidth: value too wide"); - return; - } - - const size_t num_bits = 1 + Log2U64(max_val); - WriteBits(val, num_bits); -} - BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) { buffer_.reserve(NumBitsToNumWords<64>(reserve_bits)); } @@ -340,36 +275,11 @@ bool BitReaderInterface::ReadVariableWidthU16(uint16_t* val, return ReadVariableWidthUnsigned(this, val, chunk_length); } -bool BitReaderInterface::ReadVariableWidthU8(uint8_t* val, - size_t chunk_length) { - return ReadVariableWidthUnsigned(this, val, chunk_length); -} - bool BitReaderInterface::ReadVariableWidthS64(int64_t* val, size_t chunk_length, size_t zigzag_exponent) { return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent); } -bool BitReaderInterface::ReadVariableWidthS32(int32_t* val, size_t chunk_length, - size_t zigzag_exponent) { - return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent); -} - -bool BitReaderInterface::ReadVariableWidthS16(int16_t* val, size_t chunk_length, - size_t zigzag_exponent) { - return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent); -} - -bool BitReaderInterface::ReadVariableWidthS8(int8_t* val, size_t chunk_length, - size_t zigzag_exponent) { - return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent); -} - -bool BitReaderInterface::ReadFixedWidth(uint64_t* val, uint64_t max_val) { - const size_t num_bits = 1 + Log2U64(max_val); - return ReadBits(val, num_bits) == num_bits; -} - BitReaderWord64::BitReaderWord64(std::vector&& buffer) : buffer_(std::move(buffer)), pos_(0) {} @@ -434,4 +344,5 @@ bool BitReaderWord64::OnlyZeroesLeft() const { return !remaining_bits; } -} // namespace spvutils +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/util/bit_stream.h b/3rdparty/spirv-tools/source/comp/bit_stream.h similarity index 60% rename from 3rdparty/spirv-tools/source/util/bit_stream.h rename to 3rdparty/spirv-tools/source/comp/bit_stream.h index 247ae3545..5f82344d6 100644 --- a/3rdparty/spirv-tools/source/util/bit_stream.h +++ b/3rdparty/spirv-tools/source/comp/bit_stream.h @@ -14,21 +14,22 @@ // Contains utils for reading, writing and debug printing bit streams. -#ifndef LIBSPIRV_UTIL_BIT_STREAM_H_ -#define LIBSPIRV_UTIL_BIT_STREAM_H_ +#ifndef SOURCE_COMP_BIT_STREAM_H_ +#define SOURCE_COMP_BIT_STREAM_H_ #include #include +#include #include +#include #include #include #include +#include #include -namespace spvutils { - -// Returns rounded down log2(val). log2(0) is considered 0. -size_t Log2U64(uint64_t val); +namespace spvtools { +namespace comp { // Terminology: // Bits - usually used for a uint64 word, first bit is the lowest. @@ -52,33 +53,6 @@ inline T GetLowerBits(T in, size_t num_bits) { return sizeof(T) * 8 == num_bits ? in : in & T((T(1) << num_bits) - T(1)); } -// Encodes signed integer as unsigned in zigzag order: -// 0 -> 0 -// -1 -> 1 -// 1 -> 2 -// -2 -> 3 -// 2 -> 4 -// Motivation: -1 is 0xFF...FF what doesn't work very well with -// WriteVariableWidth which prefers to have as many 0 bits as possible. -inline uint64_t EncodeZigZag(int64_t val) { return (val << 1) ^ (val >> 63); } - -// Decodes signed integer encoded with EncodeZigZag. -inline int64_t DecodeZigZag(uint64_t val) { - if (val & 1) { - // Negative. - // 1 -> -1 - // 3 -> -2 - // 5 -> -3 - return -1 - (val >> 1); - } else { - // Non-negative. - // 0 -> 0 - // 2 -> 1 - // 4 -> 2 - return val >> 1; - } -} - // Encodes signed integer as unsigned. This is a generalized version of // EncodeZigZag, designed to favor small positive numbers. // Values are transformed in blocks of 2^|block_exponent|. @@ -111,111 +85,24 @@ inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) { } } -// Converts |buffer| to a stream of '0' and '1'. -template -std::string BufferToStream(const std::vector& buffer) { - std::stringstream ss; - for (auto it = buffer.begin(); it != buffer.end(); ++it) { - std::string str = std::bitset(*it).to_string(); - // Strings generated by std::bitset::to_string are read right to left. - // Reversing to left to right. - std::reverse(str.begin(), str.end()); - ss << str; - } - return ss.str(); -} - -// Converts a left-to-right input string of '0' and '1' to a buffer of |T| -// words. -template -std::vector StreamToBuffer(std::string str) { - // The input string is left-to-right, the input argument of std::bitset needs - // to right-to-left. Instead of reversing tokens, reverse the entire string - // and iterate tokens from end to begin. - std::reverse(str.begin(), str.end()); - const int word_size = static_cast(sizeof(T) * 8); - const int str_length = static_cast(str.length()); - std::vector buffer; - buffer.reserve(NumBitsToNumWords(str.length())); - for (int index = str_length - word_size; index >= 0; index -= word_size) { - buffer.push_back(static_cast( - std::bitset(str, index, word_size).to_ullong())); - } - const size_t suffix_length = str.length() % word_size; - if (suffix_length != 0) { - buffer.push_back(static_cast( - std::bitset(str, 0, suffix_length).to_ullong())); - } - return buffer; -} - -// Adds '0' chars at the end of the string until the size is a multiple of N. -template -inline std::string PadToWord(std::string&& str) { - const size_t tail_length = str.size() % N; - if (tail_length != 0) str += std::string(N - tail_length, '0'); - return str; -} - -// Adds '0' chars at the end of the string until the size is a multiple of N. -template -inline std::string PadToWord(const std::string& str) { - return PadToWord(std::string(str)); -} - -// Converts a left-to-right stream of bits to std::bitset. -template -inline std::bitset StreamToBitset(std::string str) { - std::reverse(str.begin(), str.end()); - return std::bitset(str); -} - -// Converts first |num_bits| of std::bitset to a left-to-right stream of bits. -template -inline std::string BitsetToStream(const std::bitset& bits, - size_t num_bits = N) { - std::string str = bits.to_string().substr(N - num_bits); - std::reverse(str.begin(), str.end()); - return str; -} - -// Converts a left-to-right stream of bits to uint64. -inline uint64_t StreamToBits(std::string str) { - std::reverse(str.begin(), str.end()); - return std::bitset<64>(str).to_ullong(); -} - // Converts first |num_bits| stored in uint64 to a left-to-right stream of bits. inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) { std::bitset<64> bitset(bits); - return BitsetToStream(bitset, num_bits); + std::string str = bitset.to_string().substr(64 - num_bits); + std::reverse(str.begin(), str.end()); + return str; } // Base class for writing sequences of bits. class BitWriterInterface { public: - BitWriterInterface() {} - virtual ~BitWriterInterface() {} + BitWriterInterface() = default; + virtual ~BitWriterInterface() = default; // Writes lower |num_bits| in |bits| to the stream. // |num_bits| must be no greater than 64. virtual void WriteBits(uint64_t bits, size_t num_bits) = 0; - // Writes left-to-right string of '0' and '1' to stream. - // String length must be no greater than 64. - // Note: "01" will be writen as 0x2, not 0x1. The string doesn't represent - // numbers but a stream of bits in the order they come from encoder. - virtual void WriteStream(const std::string& bits) { - WriteBits(StreamToBits(bits), bits.length()); - } - - // Writes lower |num_bits| in |bits| to the stream. - // |num_bits| must be no greater than 64. - template - void WriteBitset(const std::bitset& bits, size_t num_bits = N) { - WriteBits(bits.to_ullong(), num_bits); - } - // Writes bits from value of type |T| to the stream. No encoding is done. // Always writes 8 * sizeof(T) bits. template @@ -235,27 +122,8 @@ class BitWriterInterface { void WriteVariableWidthU64(uint64_t val, size_t chunk_length); void WriteVariableWidthU32(uint32_t val, size_t chunk_length); void WriteVariableWidthU16(uint16_t val, size_t chunk_length); - void WriteVariableWidthU8(uint8_t val, size_t chunk_length); void WriteVariableWidthS64(int64_t val, size_t chunk_length, size_t zigzag_exponent); - void WriteVariableWidthS32(int32_t val, size_t chunk_length, - size_t zigzag_exponent); - void WriteVariableWidthS16(int16_t val, size_t chunk_length, - size_t zigzag_exponent); - void WriteVariableWidthS8(int8_t val, size_t chunk_length, - size_t zigzag_exponent); - - // Writes |val| using fixed bit width. Bit width is determined by |max_val|: - // max_val 0 -> bit width 1 - // max_val 1 -> bit width 1 - // max_val 2 -> bit width 2 - // max_val 3 -> bit width 2 - // max_val 4 -> bit width 3 - // max_val 5 -> bit width 3 - // max_val 8 -> bit width 4 - // max_val n -> bit width 1 + floor(log2(n)) - // |val| needs to be <= |max_val|. - void WriteFixedWidth(uint64_t val, uint64_t max_val); // Returns number of bits written. virtual size_t GetNumBits() const = 0; @@ -291,10 +159,6 @@ class BitWriterWord64 : public BitWriterInterface { return std::vector(GetData(), GetData() + GetDataSizeBytes()); } - // Returns written stream as std::string, padded with zeroes so that the - // length is a multiple of 64. - std::string GetStreamPadded64() const { return BufferToStream(buffer_); } - // Sets callback to emit bit sequences after every write. void SetCallback(std::function callback) { callback_ = callback; @@ -326,27 +190,6 @@ class BitReaderInterface { // Returns number of read bits. |num_bits| must be no greater than 64. virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0; - // Reads |num_bits| from the stream, stores them in |bits|. - // Returns number of read bits. |num_bits| must be no greater than 64. - template - size_t ReadBitset(std::bitset* bits, size_t num_bits = N) { - uint64_t val = 0; - size_t num_read = ReadBits(&val, num_bits); - if (num_read) { - *bits = std::bitset(val); - } - return num_read; - } - - // Reads |num_bits| from the stream, returns string in left-to-right order. - // The length of the returned string may be less than |num_bits| if end was - // reached. - std::string ReadStream(size_t num_bits) { - uint64_t bits = 0; - size_t num_read = ReadBits(&bits, num_bits); - return BitsToStream(bits, num_read); - } - // Reads 8 * sizeof(T) bits and stores them in |val|. template bool ReadUnencoded(T* val) { @@ -381,19 +224,8 @@ class BitReaderInterface { bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length); bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length); bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length); - bool ReadVariableWidthU8(uint8_t* val, size_t chunk_length); bool ReadVariableWidthS64(int64_t* val, size_t chunk_length, size_t zigzag_exponent); - bool ReadVariableWidthS32(int32_t* val, size_t chunk_length, - size_t zigzag_exponent); - bool ReadVariableWidthS16(int16_t* val, size_t chunk_length, - size_t zigzag_exponent); - bool ReadVariableWidthS8(int8_t* val, size_t chunk_length, - size_t zigzag_exponent); - - // Reads value written by WriteFixedWidth (|max_val| needs to be the same). - // Returns true on success, false if the bit stream ends prematurely. - bool ReadFixedWidth(uint64_t* val, uint64_t max_val); BitReaderInterface(const BitReaderInterface&) = delete; BitReaderInterface& operator=(const BitReaderInterface&) = delete; @@ -442,6 +274,7 @@ class BitReaderWord64 : public BitReaderInterface { std::function callback_; }; -} // namespace spvutils +} // namespace comp +} // namespace spvtools -#endif // LIBSPIRV_UTIL_BIT_STREAM_H_ +#endif // SOURCE_COMP_BIT_STREAM_H_ diff --git a/3rdparty/spirv-tools/source/util/huffman_codec.h b/3rdparty/spirv-tools/source/comp/huffman_codec.h similarity index 98% rename from 3rdparty/spirv-tools/source/util/huffman_codec.h rename to 3rdparty/spirv-tools/source/comp/huffman_codec.h index c2f7b1a98..166021614 100644 --- a/3rdparty/spirv-tools/source/util/huffman_codec.h +++ b/3rdparty/spirv-tools/source/comp/huffman_codec.h @@ -14,8 +14,8 @@ // Contains utils for reading, writing and debug printing bit streams. -#ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_ -#define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_ +#ifndef SOURCE_COMP_HUFFMAN_CODEC_H_ +#define SOURCE_COMP_HUFFMAN_CODEC_H_ #include #include @@ -27,11 +27,14 @@ #include #include #include +#include #include #include +#include #include -namespace spvutils { +namespace spvtools { +namespace comp { // Used to generate and apply a Huffman coding scheme. // |Val| is the type of variable being encoded (for example a string or a @@ -380,6 +383,7 @@ class HuffmanCodec { uint32_t next_node_id_ = 1; }; -} // namespace spvutils +} // namespace comp +} // namespace spvtools -#endif // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_ +#endif // SOURCE_COMP_HUFFMAN_CODEC_H_ diff --git a/3rdparty/spirv-tools/source/comp/markv.cpp b/3rdparty/spirv-tools/source/comp/markv.cpp new file mode 100644 index 000000000..736bc51ba --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/markv.h" + +#include "source/comp/markv_decoder.h" +#include "source/comp/markv_encoder.h" + +namespace spvtools { +namespace comp { +namespace { + +spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian, + uint32_t magic, uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t schema) { + MarkvEncoder* encoder = reinterpret_cast(user_data); + return encoder->EncodeHeader(endian, magic, version, generator, id_bound, + schema); +} + +spv_result_t EncodeInstruction(void* user_data, + const spv_parsed_instruction_t* inst) { + MarkvEncoder* encoder = reinterpret_cast(user_data); + return encoder->EncodeInstruction(*inst); +} + +} // namespace + +spv_result_t SpirvToMarkv( + spv_const_context context, const std::vector& spirv, + const MarkvCodecOptions& options, const MarkvModel& markv_model, + MessageConsumer message_consumer, MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer, std::vector* markv) { + spv_context_t hijack_context = *context; + SetContextMessageConsumer(&hijack_context, message_consumer); + + spv_validator_options validator_options = + MarkvDecoder::GetValidatorOptions(options); + if (validator_options) { + spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()}; + const spv_result_t result = spvValidateWithOptions( + &hijack_context, validator_options, &spirv_binary, nullptr); + if (result != SPV_SUCCESS) return result; + } + + MarkvEncoder encoder(&hijack_context, options, &markv_model); + + spv_position_t position = {}; + if (log_consumer || debug_consumer) { + encoder.CreateLogger(log_consumer, debug_consumer); + + spv_text text = nullptr; + if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(), + SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, + nullptr) != SPV_SUCCESS) { + return DiagnosticStream(position, hijack_context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Failed to disassemble SPIR-V binary."; + } + assert(text); + encoder.SetDisassembly(std::string(text->str, text->length)); + spvTextDestroy(text); + } + + if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(), + EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) { + return DiagnosticStream(position, hijack_context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Unable to encode to MARK-V."; + } + + *markv = encoder.GetMarkvBinary(); + return SPV_SUCCESS; +} + +spv_result_t MarkvToSpirv( + spv_const_context context, const std::vector& markv, + const MarkvCodecOptions& options, const MarkvModel& markv_model, + MessageConsumer message_consumer, MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer, std::vector* spirv) { + spv_position_t position = {}; + spv_context_t hijack_context = *context; + SetContextMessageConsumer(&hijack_context, message_consumer); + + MarkvDecoder decoder(&hijack_context, markv, options, &markv_model); + + if (log_consumer || debug_consumer) + decoder.CreateLogger(log_consumer, debug_consumer); + + if (decoder.DecodeModule(spirv) != SPV_SUCCESS) { + return DiagnosticStream(position, hijack_context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Unable to decode MARK-V."; + } + + assert(!spirv->empty()); + return SPV_SUCCESS; +} + +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/comp/markv.h b/3rdparty/spirv-tools/source/comp/markv.h index 288e68085..587086f91 100644 --- a/3rdparty/spirv-tools/source/comp/markv.h +++ b/3rdparty/spirv-tools/source/comp/markv.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Google Inc. +// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,16 +18,15 @@ // make it more similar to other compressed SPIR-V files to further improve // compression of the dataset. -#ifndef SPIRV_TOOLS_MARKV_HPP_ -#define SPIRV_TOOLS_MARKV_HPP_ +#ifndef SOURCE_COMP_MARKV_H_ +#define SOURCE_COMP_MARKV_H_ -#include -#include - -#include "markv_model.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { +namespace comp { + +class MarkvModel; struct MarkvCodecOptions { bool validate_spirv_binary = false; @@ -69,6 +68,7 @@ spv_result_t MarkvToSpirv( MessageConsumer message_consumer, MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer, std::vector* spirv); +} // namespace comp } // namespace spvtools -#endif // SPIRV_TOOLS_MARKV_HPP_ +#endif // SOURCE_COMP_MARKV_H_ diff --git a/3rdparty/spirv-tools/source/comp/markv_codec.cpp b/3rdparty/spirv-tools/source/comp/markv_codec.cpp index cbc25ab02..ae3ce79f2 100644 --- a/3rdparty/spirv-tools/source/comp/markv_codec.cpp +++ b/3rdparty/spirv-tools/source/comp/markv_codec.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Google Inc. +// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,148 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Contains -// - SPIR-V to MARK-V encoder -// - MARK-V to SPIR-V decoder -// // MARK-V is a compression format for SPIR-V binaries. It strips away -// non-essential information (such as result ids which can be regenerated) and -// uses various bit reduction techiniques to reduce the size of the binary. +// non-essential information (such as result IDs which can be regenerated) and +// uses various bit reduction techniques to reduce the size of the binary. -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "source/comp/markv_codec.h" -#include "latest_version_glsl_std_450_header.h" -#include "latest_version_opencl_std_header.h" -#include "latest_version_spirv_header.h" - -#include "binary.h" -#include "diagnostic.h" -#include "enum_string_mapping.h" -#include "ext_inst.h" -#include "extensions.h" -#include "id_descriptor.h" -#include "instruction.h" -#include "markv.h" -#include "markv_model.h" -#include "opcode.h" -#include "operand.h" -#include "spirv-tools/libspirv.h" -#include "spirv_endian.h" -#include "spirv_validator_options.h" -#include "util/bit_stream.h" -#include "util/huffman_codec.h" -#include "util/move_to_front.h" -#include "util/parse_number.h" -#include "val/instruction.h" -#include "val/validation_state.h" -#include "validate.h" - -using libspirv::DiagnosticStream; -using libspirv::IdDescriptorCollection; -using libspirv::Instruction; -using libspirv::ValidationState_t; -using spvutils::BitReaderWord64; -using spvutils::BitWriterWord64; -using spvutils::HuffmanCodec; -using MoveToFront = spvutils::MoveToFront; -using MultiMoveToFront = spvutils::MultiMoveToFront; +#include "source/comp/markv_logger.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "source/opcode.h" +#include "source/util/make_unique.h" namespace spvtools { - +namespace comp { namespace { -const uint32_t kSpirvMagicNumber = SpvMagicNumber; -const uint32_t kMarkvMagicNumber = 0x07230303; - -// Handles for move-to-front sequences. Enums which end with "Begin" define -// handle spaces which start at that value and span 16 or 32 bit wide. -enum : uint64_t { - kMtfNone = 0, - // All ids. - kMtfAll, - // All forward declared ids. - kMtfForwardDeclared, - // All type ids except for generated by OpTypeFunction. - kMtfTypeNonFunction, - // All labels. - kMtfLabel, - // All ids created by instructions which had type_id. - kMtfObject, - // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool. - kMtfTypeScalar, - // All composite types. - kMtfTypeComposite, - // Boolean type or any vector type of it. - kMtfTypeBoolScalarOrVector, - // All float types or any vector floats type. - kMtfTypeFloatScalarOrVector, - // All int types or any vector int type. - kMtfTypeIntScalarOrVector, - // All types declared as return types in OpTypeFunction. - kMtfTypeReturnedByFunction, - // All composite objects. - kMtfComposite, - // All bool objects or vectors of bools. - kMtfBoolScalarOrVector, - // All float objects or vectors of float. - kMtfFloatScalarOrVector, - // All int objects or vectors of int. - kMtfIntScalarOrVector, - // All pointer types which point to composited. - kMtfTypePointerToComposite, - // Used by EncodeMtfRankHuffman. - kMtfGenericNonZeroRank, - // Handle space for ids of specific type. - kMtfIdOfTypeBegin = 0x10000, - // Handle space for ids generated by specific opcode. - kMtfIdGeneratedByOpcode = 0x20000, - // Handle space for ids of objects with type generated by specific opcode. - kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000, - // All vectors of specific component type. - kMtfVectorOfComponentTypeBegin = 0x40000, - // All vector types of specific size. - kMtfTypeVectorOfSizeBegin = 0x50000, - // All pointer types to specific type. - kMtfPointerToTypeBegin = 0x60000, - // All function types which return specific type. - kMtfFunctionTypeWithReturnTypeBegin = 0x70000, - // All function objects which return specific type. - kMtfFunctionWithReturnTypeBegin = 0x80000, - // Short id descriptor space (max 16-bit). - kMtfShortIdDescriptorSpaceBegin = 0x90000, - // Long id descriptor space (32-bit). - kMtfLongIdDescriptorSpaceBegin = 0x100000000, -}; - -// Signals that the value is not in the coding scheme and a fallback method -// needs to be used. -const uint64_t kMarkvNoneOfTheAbove = MarkvModel::GetMarkvNoneOfTheAbove(); - -// Mtf ranks smaller than this are encoded with Huffman coding. -const uint32_t kMtfSmallestRankEncodedByValue = 10; - -// Signals that the mtf rank is too large to be encoded with Huffman. -const uint32_t kMtfRankEncodedByValueSignal = - std::numeric_limits::max(); - -const size_t kCommentNumWhitespaces = 2; - -const size_t kByteBreakAfterInstIfLessThanUntilNextByte = 8; - -const uint32_t kShortDescriptorNumBits = 8; - // Custom hash function used to produce short descriptors. uint32_t ShortHashU32Array(const std::vector& words) { // The hash function is a sum of hashes of each word seeded by word index. @@ -163,7 +37,7 @@ uint32_t ShortHashU32Array(const std::vector& words) { for (uint32_t i = 0; i < words.size(); ++i) { val += (words[i] + i + 123) * kKnuthMulHash; } - return 1 + val % ((1 << kShortDescriptorNumBits) - 1); + return 1 + val % ((1 << MarkvCodec::kShortDescriptorNumBits) - 1); } // Returns a set of mtf rank codecs based on a plausible hand-coded @@ -174,7 +48,7 @@ GetMtfHuffmanCodecs() { std::unique_ptr> codec; - codec.reset(new HuffmanCodec(std::map({ + codec = MakeUnique>(std::map({ {0, 5}, {1, 40}, {2, 10}, @@ -185,11 +59,11 @@ GetMtfHuffmanCodecs() { {7, 3}, {8, 3}, {9, 3}, - {kMtfRankEncodedByValueSignal, 10}, - }))); + {MarkvCodec::kMtfRankEncodedByValueSignal, 10}, + })); codecs.emplace(kMtfAll, std::move(codec)); - codec.reset(new HuffmanCodec(std::map({ + codec = MakeUnique>(std::map({ {1, 50}, {2, 20}, {3, 5}, @@ -199,16 +73,57 @@ GetMtfHuffmanCodecs() { {7, 1}, {8, 1}, {9, 1}, - {kMtfRankEncodedByValueSignal, 10}, - }))); + {MarkvCodec::kMtfRankEncodedByValueSignal, 10}, + })); codecs.emplace(kMtfGenericNonZeroRank, std::move(codec)); return codecs; } +} // namespace + +const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303; + +const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10; + +const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal = + std::numeric_limits::max(); + +const uint32_t MarkvCodec::kShortDescriptorNumBits = 8; + +const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8; + +MarkvCodec::MarkvCodec(spv_const_context context, + spv_validator_options validator_options, + const MarkvModel* model) + : validator_options_(validator_options), + grammar_(context), + model_(model), + short_id_descriptors_(ShortHashU32Array), + mtf_huffman_codecs_(GetMtfHuffmanCodecs()), + context_(context) {} + +MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); } + +MarkvCodec::MarkvHeader::MarkvHeader() + : magic_number(MarkvCodec::kMarkvMagicNumber), + markv_version(MarkvCodec::GetMarkvVersion()) {} + +// Defines and returns current MARK-V version. +// static +uint32_t MarkvCodec::GetMarkvVersion() { + const uint32_t kVersionMajor = 1; + const uint32_t kVersionMinor = 4; + return kVersionMinor | (kVersionMajor << 16); +} + +size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const { + return (8 - (bit_pos % 8)) % 8; +} + // Returns true if the opcode has a fixed number of operands. May return a // false negative. -bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) { +bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const { switch (opcode) { // TODO(atgoo@github.com) This is not a complete list. case SpvOpNop: @@ -272,612 +187,8 @@ bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) { return false; } -size_t GetNumBitsToNextByte(size_t bit_pos) { return (8 - (bit_pos % 8)) % 8; } - -// Defines and returns current MARK-V version. -uint32_t GetMarkvVersion() { - const uint32_t kVersionMajor = 1; - const uint32_t kVersionMinor = 4; - return kVersionMinor | (kVersionMajor << 16); -} - -class MarkvLogger { - public: - MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer) - : log_consumer_(log_consumer), debug_consumer_(debug_consumer) {} - - void AppendText(const std::string& str) { - Append(str); - use_delimiter_ = false; - } - - void AppendTextNewLine(const std::string& str) { - Append(str); - Append("\n"); - use_delimiter_ = false; - } - - void AppendBitSequence(const std::string& str) { - if (debug_consumer_) instruction_bits_ << str; - if (use_delimiter_) Append("-"); - Append(str); - use_delimiter_ = true; - } - - void AppendWhitespaces(size_t num) { - Append(std::string(num, ' ')); - use_delimiter_ = false; - } - - void NewLine() { - Append("\n"); - use_delimiter_ = false; - } - - bool DebugInstruction(const spv_parsed_instruction_t& inst) { - bool result = true; - if (debug_consumer_) { - result = debug_consumer_( - std::vector(inst.words, inst.words + inst.num_words), - instruction_bits_.str(), instruction_comment_.str()); - instruction_bits_.str(std::string()); - instruction_comment_.str(std::string()); - } - return result; - } - - private: - MarkvLogger(const MarkvLogger&) = delete; - MarkvLogger(MarkvLogger&&) = delete; - MarkvLogger& operator=(const MarkvLogger&) = delete; - MarkvLogger& operator=(MarkvLogger&&) = delete; - - void Append(const std::string& str) { - if (log_consumer_) log_consumer_(str); - if (debug_consumer_) instruction_comment_ << str; - } - - MarkvLogConsumer log_consumer_; - MarkvDebugConsumer debug_consumer_; - - std::stringstream instruction_bits_; - std::stringstream instruction_comment_; - - // If true a delimiter will be appended before the next bit sequence. - // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0. - bool use_delimiter_ = false; -}; - -// Base class for MARK-V encoder and decoder. Contains common functionality -// such as: -// - Validator connection and validation state. -// - SPIR-V grammar and helper functions. -class MarkvCodecBase { - public: - virtual ~MarkvCodecBase() { spvValidatorOptionsDestroy(validator_options_); } - - MarkvCodecBase() = delete; - - protected: - struct MarkvHeader { - MarkvHeader() { - magic_number = kMarkvMagicNumber; - markv_version = GetMarkvVersion(); - markv_model = 0; - markv_length_in_bits = 0; - spirv_version = 0; - spirv_generator = 0; - } - - uint32_t magic_number; - uint32_t markv_version; - // Magic number to identify or verify MarkvModel used for encoding. - uint32_t markv_model; - uint32_t markv_length_in_bits; - uint32_t spirv_version; - uint32_t spirv_generator; - }; - - // |model| is owned by the caller, must be not null and valid during the - // lifetime of the codec. - explicit MarkvCodecBase(spv_const_context context, - spv_validator_options validator_options, - const MarkvModel* model) - : validator_options_(validator_options), - grammar_(context), - model_(model), - short_id_descriptors_(ShortHashU32Array), - mtf_huffman_codecs_(GetMtfHuffmanCodecs()), - context_(context), - vstate_(validator_options - ? new ValidationState_t(context, validator_options_) - : nullptr) {} - - // Validates a single instruction and updates validation state of the module. - // Does nothing and returns SPV_SUCCESS if validator was not created. - spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) { - if (!vstate_) return SPV_SUCCESS; - - return ValidateInstructionAndUpdateValidationState(vstate_.get(), &inst); - } - - // Returns instruction which created |id| or nullptr if such instruction was - // not registered. - const Instruction* FindDef(uint32_t id) const { - const auto it = id_to_def_instruction_.find(id); - if (it == id_to_def_instruction_.end()) return nullptr; - return it->second; - } - - // Returns type id of vector type component. - uint32_t GetVectorComponentType(uint32_t vector_type_id) const { - const Instruction* type_inst = FindDef(vector_type_id); - assert(type_inst); - assert(type_inst->opcode() == SpvOpTypeVector); - - const uint32_t component_type = - type_inst->word(type_inst->operands()[1].offset); - return component_type; - } - - // Returns mtf handle for ids of given type. - uint64_t GetMtfIdOfType(uint32_t type_id) const { - return kMtfIdOfTypeBegin + type_id; - } - - // Returns mtf handle for ids generated by given opcode. - uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const { - return kMtfIdGeneratedByOpcode + opcode; - } - - // Returns mtf handle for ids of type generated by given opcode. - uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const { - return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode; - } - - // Returns mtf handle for vectors of specific component type. - uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const { - return kMtfVectorOfComponentTypeBegin + type_id; - } - - // Returns mtf handle for vector type of specific size. - uint64_t GetMtfTypeVectorOfSize(uint32_t size) const { - return kMtfTypeVectorOfSizeBegin + size; - } - - // Returns mtf handle for pointers to specific size. - uint64_t GetMtfPointerToType(uint32_t type_id) const { - return kMtfPointerToTypeBegin + type_id; - } - - // Returns mtf handle for function types with given return type. - uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const { - return kMtfFunctionTypeWithReturnTypeBegin + type_id; - } - - // Returns mtf handle for functions with given return type. - uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const { - return kMtfFunctionWithReturnTypeBegin + type_id; - } - - // Returns mtf handle for the given long id descriptor. - uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const { - return kMtfLongIdDescriptorSpaceBegin + descriptor; - } - - // Returns mtf handle for the given short id descriptor. - uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const { - return kMtfShortIdDescriptorSpaceBegin + descriptor; - } - - // Process data from the current instruction. This would update MTFs and - // other data containers. - void ProcessCurInstruction(); - - // Returns move-to-front handle to be used for the current operand slot. - // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar. - uint64_t GetRuleBasedMtf(); - - // Returns words of the current instruction. Decoder has a different - // implementation and the array is valid only until the previously decoded - // word. - virtual const uint32_t* GetInstWords() const { return inst_.words; } - - // Returns the opcode of the previous instruction. - SpvOp GetPrevOpcode() const { - if (instructions_.empty()) return SpvOpNop; - - return instructions_.back()->opcode(); - } - - // Returns diagnostic stream, position index is set to instruction number. - DiagnosticStream Diag(spv_result_t error_code) const { - return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer, - error_code); - } - - // Returns current id bound. - uint32_t GetIdBound() const { return id_bound_; } - - // Sets current id bound, expected to be no lower than the previous one. - void SetIdBound(uint32_t id_bound) { - assert(id_bound >= id_bound_); - id_bound_ = id_bound; - if (vstate_) vstate_->setIdBound(id_bound); - } - - // Returns Huffman codec for ranks of the mtf with given |handle|. - // Different mtfs can use different rank distributions. - // May return nullptr if the codec doesn't exist. - const spvutils::HuffmanCodec* GetMtfHuffmanCodec( - uint64_t handle) const { - const auto it = mtf_huffman_codecs_.find(handle); - if (it == mtf_huffman_codecs_.end()) return nullptr; - return it->second.get(); - } - - // Promotes id in all move-to-front sequences if ids can be shared by multiple - // sequences. - void PromoteIfNeeded(uint32_t id) { - if (!model_->AnyDescriptorHasCodingScheme() && - model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kShortDescriptor) { - // Move-to-front sequences do not share ids. Nothing to do. - return; - } - multi_mtf_.Promote(id); - } - - spv_validator_options validator_options_ = nullptr; - const libspirv::AssemblyGrammar grammar_; - MarkvHeader header_; - - // MARK-V model, not owned. - const MarkvModel* model_ = nullptr; - - // Current instruction, current operand and current operand index. - spv_parsed_instruction_t inst_; - spv_parsed_operand_t operand_; - uint32_t operand_index_; - - // Maps a result ID to its type ID. By convention: - // - a result ID that is a type definition maps to itself. - // - a result ID without a type maps to 0. (E.g. for OpLabel) - std::unordered_map id_to_type_id_; - - // Container for all move-to-front sequences. - MultiMoveToFront multi_mtf_; - - // Id of the current function or zero if outside of function. - uint32_t cur_function_id_ = 0; - - // Return type of the current function. - uint32_t cur_function_return_type_ = 0; - - // Remaining function parameter types. This container is filled on OpFunction, - // and drained on OpFunctionParameter. - std::list remaining_function_parameter_types_; - - // List of ids local to the current function. - std::vector ids_local_to_cur_function_; - - // List of instructions in the order they are given in the module. - std::vector> instructions_; - - // Container/computer for long (32-bit) id descriptors. - IdDescriptorCollection long_id_descriptors_; - - // Container/computer for short id descriptors. - // Short descriptors are stored in uint32_t, but their actual bit width is - // defined with kShortDescriptorNumBits. - // It doesn't seem logical to have a different computer for short id - // descriptors, since one could actually map/truncate long descriptors. - // But as short descriptors have collisions, the efficiency of - // compression depends on the collision pattern, and short descriptors - // produced by function ShortHashU32Array have been empirically proven to - // produce better results. - IdDescriptorCollection short_id_descriptors_; - - // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't - // need to contain a different codec for every handle as most use one and the - // same. - std::map>> - mtf_huffman_codecs_; - - // If not nullptr, codec will log comments on the compression process. - std::unique_ptr logger_; - - private: - spv_const_context context_ = nullptr; - - std::unique_ptr vstate_; - - // Maps result id to the instruction which defined it. - std::unordered_map id_to_def_instruction_; - - uint32_t id_bound_ = 1; -}; - -// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and -// EncodeInstruction which can be used as callback by spvBinaryParse. -// Encoded binary is written to an internally maintained bitstream. -// After the last instruction is encoded, the resulting MARK-V binary can be -// acquired by calling GetMarkvBinary(). -// The encoder uses SPIR-V validator to keep internal state, therefore -// SPIR-V binary needs to be able to pass validator checks. -// CreateCommentsLogger() can be used to enable the encoder to write comments -// on how encoding was done, which can later be accessed with GetComments(). -class MarkvEncoder : public MarkvCodecBase { - public: - // |model| is owned by the caller, must be not null and valid during the - // lifetime of MarkvEncoder. - MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options, - const MarkvModel* model) - : MarkvCodecBase(context, GetValidatorOptions(options), model), - options_(options) { - (void)options_; - } - - // Writes data from SPIR-V header to MARK-V header. - spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */, - uint32_t version, uint32_t generator, - uint32_t id_bound, uint32_t /* schema */) { - SetIdBound(id_bound); - header_.spirv_version = version; - header_.spirv_generator = generator; - return SPV_SUCCESS; - } - - // Creates an internal logger which writes comments on the encoding process. - void CreateLogger(MarkvLogConsumer log_consumer, - MarkvDebugConsumer debug_consumer) { - logger_.reset(new MarkvLogger(log_consumer, debug_consumer)); - writer_.SetCallback( - [this](const std::string& str) { logger_->AppendBitSequence(str); }); - } - - // Encodes SPIR-V instruction to MARK-V and writes to bit stream. - // Operation can fail if the instruction fails to pass the validator or if - // the encoder stubmles on something unexpected. - spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst); - - // Concatenates MARK-V header and the bit stream with encoded instructions - // into a single buffer and returns it as spv_markv_binary. The returned - // value is owned by the caller and needs to be destroyed with - // spvMarkvBinaryDestroy(). - std::vector GetMarkvBinary() { - header_.markv_length_in_bits = - static_cast(sizeof(header_) * 8 + writer_.GetNumBits()); - header_.markv_model = - (model_->model_type() << 16) | model_->model_version(); - - const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes(); - std::vector markv(num_bytes); - - assert(writer_.GetData()); - std::memcpy(markv.data(), &header_, sizeof(header_)); - std::memcpy(markv.data() + sizeof(header_), writer_.GetData(), - writer_.GetDataSizeBytes()); - return markv; - } - - // Optionally adds disassembly to the comments. - // Disassembly should contain all instructions in the module separated by - // \n, and no header. - void SetDisassembly(std::string&& disassembly) { - disassembly_.reset(new std::stringstream(std::move(disassembly))); - } - - // Extracts the next instruction line from the disassembly and logs it. - void LogDisassemblyInstruction() { - if (logger_ && disassembly_) { - std::string line; - std::getline(*disassembly_, line, '\n'); - logger_->AppendTextNewLine(line); - } - } - - private: - // Creates and returns validator options. Returned value owned by the caller. - static spv_validator_options GetValidatorOptions( - const MarkvCodecOptions& options) { - return options.validate_spirv_binary ? spvValidatorOptionsCreate() - : nullptr; - } - - // Writes a single word to bit stream. operand_.type determines if the word is - // encoded and how. - spv_result_t EncodeNonIdWord(uint32_t word); - - // Writes both opcode and num_operands as a single code. - // Returns SPV_UNSUPPORTED iff no suitable codec was found. - spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode, - uint32_t num_operands); - - // Writes mtf rank to bit stream. |mtf| is used to determine the codec - // scheme. |fallback_method| is used if no codec defined for |mtf|. - spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, - uint64_t fallback_method); - - // Writes id using coding based on mtf associated with the id descriptor. - // Returns SPV_UNSUPPORTED iff fallback method needs to be used. - spv_result_t EncodeIdWithDescriptor(uint32_t id); - - // Writes id using coding based on the given |mtf|, which is expected to - // contain the given |id|. - spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id); - - // Writes type id of the current instruction if can't be inferred. - spv_result_t EncodeTypeId(); - - // Writes result id of the current instruction if can't be inferred. - spv_result_t EncodeResultId(); - - // Writes ids which are neither type nor result ids. - spv_result_t EncodeRefId(uint32_t id); - - // Writes bits to the stream until the beginning of the next byte if the - // number of bits until the next byte is less than |byte_break_if_less_than|. - void AddByteBreak(size_t byte_break_if_less_than); - - // Encodes a literal number operand and writes it to the bit stream. - spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand); - - MarkvCodecOptions options_; - - // Bit stream where encoded instructions are written. - BitWriterWord64 writer_; - - // If not nullptr, disassembled instruction lines will be written to comments. - // Format: \n separated instruction lines, no header. - std::unique_ptr disassembly_; -}; - -// Decodes MARK-V buffers written by MarkvEncoder. -class MarkvDecoder : public MarkvCodecBase { - public: - // |model| is owned by the caller, must be not null and valid during the - // lifetime of MarkvEncoder. - MarkvDecoder(spv_const_context context, const std::vector& markv, - const MarkvCodecOptions& options, const MarkvModel* model) - : MarkvCodecBase(context, GetValidatorOptions(options), model), - options_(options), - reader_(markv) { - (void)options_; - SetIdBound(1); - parsed_operands_.reserve(25); - inst_words_.reserve(25); - } - - // Creates an internal logger which writes comments on the decoding process. - void CreateLogger(MarkvLogConsumer log_consumer, - MarkvDebugConsumer debug_consumer) { - logger_.reset(new MarkvLogger(log_consumer, debug_consumer)); - } - - // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. - // Can be called only once. Fails if data of wrong format or ends prematurely, - // of if validation fails. - spv_result_t DecodeModule(std::vector* spirv_binary); - - private: - // Describes the format of a typed literal number. - struct NumberType { - spv_number_kind_t type; - uint32_t bit_width; - }; - - // Creates and returns validator options. Returned value owned by the caller. - static spv_validator_options GetValidatorOptions( - const MarkvCodecOptions& options) { - return options.validate_spirv_binary ? spvValidatorOptionsCreate() - : nullptr; - } - - // Reads a single bit from reader_. The read bit is stored in |bit|. - // Returns false iff reader_ fails. - bool ReadBit(bool* bit) { - uint64_t bits = 0; - const bool result = reader_.ReadBits(&bits, 1); - if (result) *bit = bits ? true : false; - return result; - }; - - // Returns ReadBit bound to the class object. - std::function GetReadBitCallback() { - return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1); - } - - // Reads a single non-id word from bit stream. operand_.type determines if - // the word needs to be decoded and how. - spv_result_t DecodeNonIdWord(uint32_t* word); - - // Reads and decodes both opcode and num_operands as a single code. - // Returns SPV_UNSUPPORTED iff no suitable codec was found. - spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode, - uint32_t* num_operands); - - // Reads mtf rank from bit stream. |mtf| is used to determine the codec - // scheme. |fallback_method| is used if no codec defined for |mtf|. - spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method, - uint32_t* rank); - - // Reads id using coding based on mtf associated with the id descriptor. - // Returns SPV_UNSUPPORTED iff fallback method needs to be used. - spv_result_t DecodeIdWithDescriptor(uint32_t* id); - - // Reads id using coding based on the given |mtf|, which is expected to - // contain the needed |id|. - spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id); - - // Reads type id of the current instruction if can't be inferred. - spv_result_t DecodeTypeId(); - - // Reads result id of the current instruction if can't be inferred. - spv_result_t DecodeResultId(); - - // Reads id which is neither type nor result id. - spv_result_t DecodeRefId(uint32_t* id); - - // Reads and discards bits until the beginning of the next byte if the - // number of bits until the next byte is less than |byte_break_if_less_than|. - bool ReadToByteBreak(size_t byte_break_if_less_than); - - // Returns instruction words decoded up to this point. - const uint32_t* GetInstWords() const override { return inst_words_.data(); } - - // Reads a literal number as it is described in |operand| from the bit stream, - // decodes and writes it to spirv_. - spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); - - // Reads instruction from bit stream, decodes and validates it. - // Decoded instruction is valid until the next call of DecodeInstruction(). - spv_result_t DecodeInstruction(); - - // Read operand from the stream decodes and validates it. - spv_result_t DecodeOperand(size_t operand_offset, - const spv_operand_type_t type, - spv_operand_pattern_t* expected_operands); - - // Records the numeric type for an operand according to the type information - // associated with the given non-zero type Id. This can fail if the type Id - // is not a type Id, or if the type Id does not reference a scalar numeric - // type. On success, return SPV_SUCCESS and populates the num_words, - // number_kind, and number_bit_width fields of parsed_operand. - spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, - uint32_t type_id); - - // Records the number type for the current instruction, if it generates a - // type. For types that aren't scalar numbers, record something with number - // kind SPV_NUMBER_NONE. - void RecordNumberType(); - - MarkvCodecOptions options_; - - // Temporary sink where decoded SPIR-V words are written. Once it contains the - // entire module, the container is moved and returned. - std::vector spirv_; - - // Bit stream containing encoded data. - BitReaderWord64 reader_; - - // Temporary storage for operands of the currently parsed instruction. - // Valid until next DecodeInstruction call. - std::vector parsed_operands_; - - // Temporary storage for current instruction words. - // Valid until next DecodeInstruction call. - std::vector inst_words_; - - // Maps a type ID to its number type description. - std::unordered_map type_id_to_number_type_info_; - - // Maps an ExtInstImport id to the extended instruction type. - std::unordered_map import_id_to_ext_inst_type_; -}; - -void MarkvCodecBase::ProcessCurInstruction() { - instructions_.emplace_back(new Instruction(&inst_)); +void MarkvCodec::ProcessCurInstruction() { + instructions_.emplace_back(new val::Instruction(&inst_)); const SpvOp opcode = SpvOp(inst_.opcode); @@ -901,7 +212,7 @@ void MarkvCodecBase::ProcessCurInstruction() { // Store function parameter types in a queue, so that we know which types // to expect in the following OpFunctionParameter instructions. - const Instruction* def_inst = FindDef(inst_.words[4]); + const val::Instruction* def_inst = FindDef(inst_.words[4]); assert(def_inst); assert(def_inst->opcode() == SpvOpTypeFunction); for (uint32_t i = 3; i < def_inst->words().size(); ++i) { @@ -1000,7 +311,7 @@ void MarkvCodecBase::ProcessCurInstruction() { } if (inst_.type_id) { - const Instruction* type_inst = FindDef(inst_.type_id); + const val::Instruction* type_inst = FindDef(inst_.type_id); assert(type_inst); multi_mtf_.Insert(kMtfObject, inst_.result_id); @@ -1078,7 +389,7 @@ void MarkvCodecBase::ProcessCurInstruction() { } } -uint64_t MarkvCodecBase::GetRuleBasedMtf() { +uint64_t MarkvCodec::GetRuleBasedMtf() { // This function is only called for id operands (but not result ids). assert(spvIsIdType(operand_.type) || operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID); @@ -1243,7 +554,7 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() { if (operand_index_ == 1) { const uint32_t pointer_id = GetInstWords()[1]; const uint32_t pointer_type = id_to_type_id_.at(pointer_id); - const Instruction* pointer_inst = FindDef(pointer_type); + const val::Instruction* pointer_inst = FindDef(pointer_type); assert(pointer_inst); assert(pointer_inst->opcode() == SpvOpTypePointer); const uint32_t data_type = @@ -1290,7 +601,7 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() { case SpvOpConstantComposite: { if (operand_index_ == 0) return kMtfTypeComposite; if (operand_index_ >= 2) { - const Instruction* composite_type_inst = FindDef(inst_.type_id); + const val::Instruction* composite_type_inst = FindDef(inst_.type_id); assert(composite_type_inst); if (composite_type_inst->opcode() == SpvOpTypeVector) { return GetMtfIdOfType(composite_type_inst->word(2)); @@ -1424,13 +735,13 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() { if (operand_index_ >= 3) { const uint32_t function_id = GetInstWords()[3]; - const Instruction* function_inst = FindDef(function_id); + const val::Instruction* function_inst = FindDef(function_id); if (!function_inst) return kMtfObject; assert(function_inst->opcode() == SpvOpFunction); const uint32_t function_type_id = function_inst->word(4); - const Instruction* function_type_inst = FindDef(function_type_id); + const val::Instruction* function_type_inst = FindDef(function_type_id); assert(function_type_inst); assert(function_type_inst->opcode() == SpvOpTypeFunction); @@ -1478,1443 +789,5 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() { return kMtfNone; } -spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) { - auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); - - if (codec) { - uint64_t bits = 0; - size_t num_bits = 0; - if (codec->Encode(word, &bits, &num_bits)) { - // Encoding successful. - writer_.WriteBits(bits, num_bits); - return SPV_SUCCESS; - } else { - // Encoding failed, write kMarkvNoneOfTheAbove flag. - if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) - return Diag(SPV_ERROR_INTERNAL) - << "Non-id word Huffman table for " - << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " - << operand_index_ << " is missing kMarkvNoneOfTheAbove"; - writer_.WriteBits(bits, num_bits); - } - } - - // Fallback encoding. - const size_t chunk_length = - model_->GetOperandVariableWidthChunkLength(operand_.type); - if (chunk_length) { - writer_.WriteVariableWidthU32(word, chunk_length); - } else { - writer_.WriteUnencoded(word); - } - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) { - auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); - - if (codec) { - uint64_t decoded_value = 0; - if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to decode non-id word with Huffman"; - - if (decoded_value != kMarkvNoneOfTheAbove) { - // The word decoded successfully. - *word = uint32_t(decoded_value); - assert(*word == decoded_value); - return SPV_SUCCESS; - } - - // Received kMarkvNoneOfTheAbove signal, use fallback decoding. - } - - const size_t chunk_length = - model_->GetOperandVariableWidthChunkLength(operand_.type); - if (chunk_length) { - if (!reader_.ReadVariableWidthU32(word, chunk_length)) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to decode non-id word with varint"; - } else { - if (!reader_.ReadUnencoded(word)) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read unencoded non-id word"; - } - return SPV_SUCCESS; -} - -spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode, - uint32_t num_operands) { - uint64_t bits = 0; - size_t num_bits = 0; - - const uint32_t word = opcode | (num_operands << 16); - - // First try to use the Markov chain codec. - auto* codec = - model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); - if (codec) { - if (codec->Encode(word, &bits, &num_bits)) { - // The word was successfully encoded into bits/num_bits. - writer_.WriteBits(bits, num_bits); - return SPV_SUCCESS; - } else { - // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove - // and use fallback encoding. - if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) - return Diag(SPV_ERROR_INTERNAL) - << "opcode_and_num_operands Huffman table for " - << spvOpcodeString(GetPrevOpcode()) - << "is missing kMarkvNoneOfTheAbove"; - writer_.WriteBits(bits, num_bits); - } - } - - // Fallback to base-rate codec. - codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); - assert(codec); - if (codec->Encode(word, &bits, &num_bits)) { - // The word was successfully encoded into bits/num_bits. - writer_.WriteBits(bits, num_bits); - return SPV_SUCCESS; - } else { - // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove - // and return false. - if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) - return Diag(SPV_ERROR_INTERNAL) - << "Global opcode_and_num_operands Huffman table is missing " - << "kMarkvNoneOfTheAbove"; - writer_.WriteBits(bits, num_bits); - return SPV_UNSUPPORTED; - } -} - -spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands( - uint32_t* opcode, uint32_t* num_operands) { - // First try to use the Markov chain codec. - auto* codec = - model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); - if (codec) { - uint64_t decoded_value = 0; - if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to decode opcode_and_num_operands, previous opcode is " - << spvOpcodeString(GetPrevOpcode()); - - if (decoded_value != kMarkvNoneOfTheAbove) { - // The word was successfully decoded. - *opcode = uint32_t(decoded_value & 0xFFFF); - *num_operands = uint32_t(decoded_value >> 16); - return SPV_SUCCESS; - } - - // Received kMarkvNoneOfTheAbove signal, use fallback decoding. - } - - // Fallback to base-rate codec. - codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); - assert(codec); - uint64_t decoded_value = 0; - if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to decode opcode_and_num_operands with global codec"; - - if (decoded_value == kMarkvNoneOfTheAbove) { - // Received kMarkvNoneOfTheAbove signal, fallback further. - return SPV_UNSUPPORTED; - } - - *opcode = uint32_t(decoded_value & 0xFFFF); - *num_operands = uint32_t(decoded_value >> 16); - return SPV_SUCCESS; -} - -spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, - uint64_t fallback_method) { - const auto* codec = GetMtfHuffmanCodec(mtf); - if (!codec) { - assert(fallback_method != kMtfNone); - codec = GetMtfHuffmanCodec(fallback_method); - } - - if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank"; - - uint64_t bits = 0; - size_t num_bits = 0; - if (rank < kMtfSmallestRankEncodedByValue) { - // Encode using Huffman coding. - if (!codec->Encode(rank, &bits, &num_bits)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to encode MTF rank with Huffman"; - - writer_.WriteBits(bits, num_bits); - } else { - // Encode by value. - if (!codec->Encode(kMtfRankEncodedByValueSignal, &bits, &num_bits)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to encode kMtfRankEncodedByValueSignal"; - - writer_.WriteBits(bits, num_bits); - writer_.WriteVariableWidthU32(rank - kMtfSmallestRankEncodedByValue, - model_->mtf_rank_chunk_length()); - } - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf, - uint32_t fallback_method, - uint32_t* rank) { - const auto* codec = GetMtfHuffmanCodec(mtf); - if (!codec) { - assert(fallback_method != kMtfNone); - codec = GetMtfHuffmanCodec(fallback_method); - } - - if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank"; - - uint32_t decoded_value = 0; - if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) - return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman"; - - if (decoded_value == kMtfRankEncodedByValueSignal) { - // Decode by value. - if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to decode MTF rank with varint"; - *rank += kMtfSmallestRankEncodedByValue; - } else { - // Decode using Huffman coding. - assert(decoded_value < kMtfSmallestRankEncodedByValue); - *rank = decoded_value; - } - return SPV_SUCCESS; -} - -spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) { - // Get the descriptor for id. - const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id); - auto* codec = - model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); - uint64_t bits = 0; - size_t num_bits = 0; - uint64_t mtf = kMtfNone; - if (long_descriptor && codec && - codec->Encode(long_descriptor, &bits, &num_bits)) { - // If the descriptor exists and is in the table, write the descriptor and - // proceed to encoding the rank. - writer_.WriteBits(bits, num_bits); - mtf = GetMtfLongIdDescriptor(long_descriptor); - } else { - if (codec) { - // The descriptor doesn't exist or we have no coding for it. Write - // kMarkvNoneOfTheAbove and go to fallback method. - if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits)) - return Diag(SPV_ERROR_INTERNAL) - << "Descriptor Huffman table for " - << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " - << operand_index_ << " is missing kMarkvNoneOfTheAbove"; - - writer_.WriteBits(bits, num_bits); - } - - if (model_->id_fallback_strategy() != - MarkvModel::IdFallbackStrategy::kShortDescriptor) { - return SPV_UNSUPPORTED; - } - - const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id); - writer_.WriteBits(short_descriptor, kShortDescriptorNumBits); - - if (short_descriptor == 0) { - // Forward declared id. - return SPV_UNSUPPORTED; - } - - mtf = GetMtfShortIdDescriptor(short_descriptor); - } - - // Descriptor has been encoded. Now encode the rank of the id in the - // associated mtf sequence. - return EncodeExistingId(mtf, id); -} - -spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) { - auto* codec = - model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); - - uint64_t mtf = kMtfNone; - if (codec) { - uint64_t decoded_value = 0; - if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to decode descriptor with Huffman"; - - if (decoded_value != kMarkvNoneOfTheAbove) { - const uint32_t long_descriptor = uint32_t(decoded_value); - mtf = GetMtfLongIdDescriptor(long_descriptor); - } - } - - if (mtf == kMtfNone) { - if (model_->id_fallback_strategy() != - MarkvModel::IdFallbackStrategy::kShortDescriptor) { - return SPV_UNSUPPORTED; - } - - uint64_t decoded_value = 0; - if (!reader_.ReadBits(&decoded_value, kShortDescriptorNumBits)) - return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor"; - const uint32_t short_descriptor = uint32_t(decoded_value); - if (short_descriptor == 0) { - // Forward declared id. - return SPV_UNSUPPORTED; - } - mtf = GetMtfShortIdDescriptor(short_descriptor); - } - - return DecodeExistingId(mtf, id); -} - -spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) { - assert(multi_mtf_.GetSize(mtf) > 0); - if (multi_mtf_.GetSize(mtf) == 1) { - // If the sequence has only one element no need to write rank, the decoder - // would make the same decision. - return SPV_SUCCESS; - } - - uint32_t rank = 0; - if (!multi_mtf_.RankFromValue(mtf, id, &rank)) - return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence"; - - return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank); -} - -spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) { - assert(multi_mtf_.GetSize(mtf) > 0); - *id = 0; - - uint32_t rank = 0; - - if (multi_mtf_.GetSize(mtf) == 1) { - rank = 1; - } else { - const spv_result_t result = - DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank); - if (result != SPV_SUCCESS) return result; - } - - assert(rank); - if (!multi_mtf_.ValueFromRank(mtf, rank, id)) - return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds"; - - return SPV_SUCCESS; -} - -spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) { - { - // Try to encode using id descriptor mtfs. - const spv_result_t result = EncodeIdWithDescriptor(id); - if (result != SPV_UNSUPPORTED) return result; - // If can't be done continue with other methods. - } - - const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( - SpvOp(inst_.opcode))(operand_index_); - uint32_t rank = 0; - - if (model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kRuleBased) { - // Encode using rule-based mtf. - uint64_t mtf = GetRuleBasedMtf(); - - if (mtf != kMtfNone && !can_forward_declare) { - assert(multi_mtf_.HasValue(kMtfAll, id)); - return EncodeExistingId(mtf, id); - } - - if (mtf == kMtfNone) mtf = kMtfAll; - - if (!multi_mtf_.RankFromValue(mtf, id, &rank)) { - // This is the first occurrence of a forward declared id. - multi_mtf_.Insert(kMtfAll, id); - multi_mtf_.Insert(kMtfForwardDeclared, id); - if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id); - rank = 0; - } - - return EncodeMtfRankHuffman(rank, mtf, kMtfAll); - } else { - assert(can_forward_declare); - - if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) { - // This is the first occurrence of a forward declared id. - multi_mtf_.Insert(kMtfForwardDeclared, id); - rank = 0; - } - - writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); - return SPV_SUCCESS; - } -} - -spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) { - { - const spv_result_t result = DecodeIdWithDescriptor(id); - if (result != SPV_UNSUPPORTED) return result; - } - - const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( - SpvOp(inst_.opcode))(operand_index_); - uint32_t rank = 0; - *id = 0; - - if (model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kRuleBased) { - uint64_t mtf = GetRuleBasedMtf(); - if (mtf != kMtfNone && !can_forward_declare) { - return DecodeExistingId(mtf, id); - } - - if (mtf == kMtfNone) mtf = kMtfAll; - { - const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank); - if (result != SPV_SUCCESS) return result; - } - - if (rank == 0) { - // This is the first occurrence of a forward declared id. - *id = GetIdBound(); - SetIdBound(*id + 1); - multi_mtf_.Insert(kMtfAll, *id); - multi_mtf_.Insert(kMtfForwardDeclared, *id); - if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id); - } else { - if (!multi_mtf_.ValueFromRank(mtf, rank, id)) - return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; - } - } else { - assert(can_forward_declare); - - if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to decode MTF rank with varint"; - - if (rank == 0) { - // This is the first occurrence of a forward declared id. - *id = GetIdBound(); - SetIdBound(*id + 1); - multi_mtf_.Insert(kMtfForwardDeclared, *id); - } else { - if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id)) - return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; - } - } - assert(*id); - return SPV_SUCCESS; -} - -spv_result_t MarkvEncoder::EncodeTypeId() { - if (inst_.opcode == SpvOpFunctionParameter) { - assert(!remaining_function_parameter_types_.empty()); - assert(inst_.type_id == remaining_function_parameter_types_.front()); - remaining_function_parameter_types_.pop_front(); - return SPV_SUCCESS; - } - - { - // Try to encode using id descriptor mtfs. - const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id); - if (result != SPV_UNSUPPORTED) return result; - // If can't be done continue with other methods. - } - - assert(model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kRuleBased); - - uint64_t mtf = GetRuleBasedMtf(); - assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( - operand_index_)); - - if (mtf == kMtfNone) { - mtf = kMtfTypeNonFunction; - // Function types should have been handled by GetRuleBasedMtf. - assert(inst_.opcode != SpvOpFunction); - } - - return EncodeExistingId(mtf, inst_.type_id); -} - -spv_result_t MarkvDecoder::DecodeTypeId() { - if (inst_.opcode == SpvOpFunctionParameter) { - assert(!remaining_function_parameter_types_.empty()); - inst_.type_id = remaining_function_parameter_types_.front(); - remaining_function_parameter_types_.pop_front(); - return SPV_SUCCESS; - } - - { - const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id); - if (result != SPV_UNSUPPORTED) return result; - } - - assert(model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kRuleBased); - - uint64_t mtf = GetRuleBasedMtf(); - assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( - operand_index_)); - - if (mtf == kMtfNone) { - mtf = kMtfTypeNonFunction; - // Function types should have been handled by GetRuleBasedMtf. - assert(inst_.opcode != SpvOpFunction); - } - - return DecodeExistingId(mtf, &inst_.type_id); -} - -spv_result_t MarkvEncoder::EncodeResultId() { - uint32_t rank = 0; - - const uint64_t num_still_forward_declared = - multi_mtf_.GetSize(kMtfForwardDeclared); - - if (num_still_forward_declared) { - // We write the rank only if kMtfForwardDeclared is not empty. If it is - // empty the decoder knows that there are no forward declared ids to expect. - if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) { - // This is a definition of a forward declared id. We can remove the id - // from kMtfForwardDeclared. - if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to remove id from kMtfForwardDeclared"; - writer_.WriteBits(1, 1); - writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); - } else { - rank = 0; - writer_.WriteBits(0, 1); - } - } - - if (model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kRuleBased) { - if (!rank) { - multi_mtf_.Insert(kMtfAll, inst_.result_id); - } - } - - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::DecodeResultId() { - uint32_t rank = 0; - - const uint64_t num_still_forward_declared = - multi_mtf_.GetSize(kMtfForwardDeclared); - - if (num_still_forward_declared) { - // Some ids were forward declared. Check if this id is one of them. - uint64_t id_was_forward_declared; - if (!reader_.ReadBits(&id_was_forward_declared, 1)) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read id_was_forward_declared flag"; - - if (id_was_forward_declared) { - if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read MTF rank of forward declared id"; - - if (rank) { - // The id was forward declared, recover it from kMtfForwardDeclared. - if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, - &inst_.result_id)) - return Diag(SPV_ERROR_INTERNAL) - << "Forward declared MTF rank is out of bounds"; - - // We can now remove the id from kMtfForwardDeclared. - if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) - return Diag(SPV_ERROR_INTERNAL) - << "Failed to remove id from kMtfForwardDeclared"; - } - } - } - - if (inst_.result_id == 0) { - // The id was not forward declared, issue a new id. - inst_.result_id = GetIdBound(); - SetIdBound(inst_.result_id + 1); - } - - if (model_->id_fallback_strategy() == - MarkvModel::IdFallbackStrategy::kRuleBased) { - if (!rank) { - multi_mtf_.Insert(kMtfAll, inst_.result_id); - } - } - - return SPV_SUCCESS; -} - -spv_result_t MarkvEncoder::EncodeLiteralNumber( - const spv_parsed_operand_t& operand) { - if (operand.number_bit_width <= 32) { - const uint32_t word = inst_.words[operand.offset]; - return EncodeNonIdWord(word); - } else { - assert(operand.number_bit_width <= 64); - const uint64_t word = uint64_t(inst_.words[operand.offset]) | - (uint64_t(inst_.words[operand.offset + 1]) << 32); - if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { - writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); - } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { - int64_t val = 0; - std::memcpy(&val, &word, 8); - writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), - model_->s64_block_exponent()); - } else if (operand.number_kind == SPV_NUMBER_FLOATING) { - writer_.WriteUnencoded(word); - } else { - return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; - } - } - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::DecodeLiteralNumber( - const spv_parsed_operand_t& operand) { - if (operand.number_bit_width <= 32) { - uint32_t word = 0; - const spv_result_t result = DecodeNonIdWord(&word); - if (result != SPV_SUCCESS) return result; - inst_words_.push_back(word); - } else { - assert(operand.number_bit_width <= 64); - uint64_t word = 0; - if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { - if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) - return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; - } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { - int64_t val = 0; - if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), - model_->s64_block_exponent())) - return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64"; - std::memcpy(&word, &val, 8); - } else if (operand.number_kind == SPV_NUMBER_FLOATING) { - if (!reader_.ReadUnencoded(&word)) - return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; - } else { - return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; - } - inst_words_.push_back(static_cast(word)); - inst_words_.push_back(static_cast(word >> 32)); - } - return SPV_SUCCESS; -} - -void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) { - const size_t num_bits_to_next_byte = - GetNumBitsToNextByte(writer_.GetNumBits()); - if (num_bits_to_next_byte == 0 || - num_bits_to_next_byte > byte_break_if_less_than) - return; - - if (logger_) { - logger_->AppendWhitespaces(kCommentNumWhitespaces); - logger_->AppendText(""); - } - - writer_.WriteBits(0, num_bits_to_next_byte); -} - -bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) { - const size_t num_bits_to_next_byte = - GetNumBitsToNextByte(reader_.GetNumReadBits()); - if (num_bits_to_next_byte == 0 || - num_bits_to_next_byte > byte_break_if_less_than) - return true; - - uint64_t bits = 0; - if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false; - - assert(bits == 0); - if (bits != 0) return false; - - return true; -} - -spv_result_t MarkvEncoder::EncodeInstruction( - const spv_parsed_instruction_t& inst) { - SpvOp opcode = SpvOp(inst.opcode); - inst_ = inst; - - const spv_result_t validation_result = UpdateValidationState(inst); - if (validation_result != SPV_SUCCESS) return validation_result; - - LogDisassemblyInstruction(); - - const spv_result_t opcode_encodig_result = - EncodeOpcodeAndNumOperands(opcode, inst.num_operands); - if (opcode_encodig_result < 0) return opcode_encodig_result; - - if (opcode_encodig_result != SPV_SUCCESS) { - // Fallback encoding for opcode and num_operands. - writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length()); - - if (!OpcodeHasFixedNumberOfOperands(opcode)) { - // If the opcode has a variable number of operands, encode the number of - // operands with the instruction. - - if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); - - writer_.WriteVariableWidthU16(inst.num_operands, - model_->num_operands_chunk_length()); - } - } - - // Write operands. - const uint32_t num_operands = inst_.num_operands; - for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) { - operand_ = inst_.operands[operand_index_]; - - if (logger_) { - logger_->AppendWhitespaces(kCommentNumWhitespaces); - logger_->AppendText("<"); - logger_->AppendText(spvOperandTypeStr(operand_.type)); - logger_->AppendText(">"); - } - - switch (operand_.type) { - case SPV_OPERAND_TYPE_RESULT_ID: - case SPV_OPERAND_TYPE_TYPE_ID: - case SPV_OPERAND_TYPE_ID: - case SPV_OPERAND_TYPE_OPTIONAL_ID: - case SPV_OPERAND_TYPE_SCOPE_ID: - case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { - const uint32_t id = inst_.words[operand_.offset]; - if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) { - const spv_result_t result = EncodeTypeId(); - if (result != SPV_SUCCESS) return result; - } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) { - const spv_result_t result = EncodeResultId(); - if (result != SPV_SUCCESS) return result; - } else { - const spv_result_t result = EncodeRefId(id); - if (result != SPV_SUCCESS) return result; - } - - PromoteIfNeeded(id); - break; - } - - case SPV_OPERAND_TYPE_LITERAL_INTEGER: { - const spv_result_t result = - EncodeNonIdWord(inst_.words[operand_.offset]); - if (result != SPV_SUCCESS) return result; - break; - } - - case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { - const spv_result_t result = EncodeLiteralNumber(operand_); - if (result != SPV_SUCCESS) return result; - break; - } - - case SPV_OPERAND_TYPE_LITERAL_STRING: { - const char* src = - reinterpret_cast(&inst_.words[operand_.offset]); - - auto* codec = model_->GetLiteralStringHuffmanCodec(opcode); - if (codec) { - uint64_t bits = 0; - size_t num_bits = 0; - const std::string str = src; - if (codec->Encode(str, &bits, &num_bits)) { - writer_.WriteBits(bits, num_bits); - break; - } else { - bool result = - codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits); - (void)result; - assert(result); - writer_.WriteBits(bits, num_bits); - } - } - - const size_t length = spv_strnlen_s(src, operand_.num_words * 4); - if (length == operand_.num_words * 4) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to find terminal character of literal string"; - for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); - break; - } - - default: { - for (int i = 0; i < operand_.num_words; ++i) { - const uint32_t word = inst_.words[operand_.offset + i]; - const spv_result_t result = EncodeNonIdWord(word); - if (result != SPV_SUCCESS) return result; - } - break; - } - } - } - - AddByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte); - - if (logger_) { - logger_->NewLine(); - logger_->NewLine(); - if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; - } - - ProcessCurInstruction(); - - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::DecodeModule(std::vector* spirv_binary) { - const bool header_read_success = - reader_.ReadUnencoded(&header_.magic_number) && - reader_.ReadUnencoded(&header_.markv_version) && - reader_.ReadUnencoded(&header_.markv_model) && - reader_.ReadUnencoded(&header_.markv_length_in_bits) && - reader_.ReadUnencoded(&header_.spirv_version) && - reader_.ReadUnencoded(&header_.spirv_generator); - - if (!header_read_success) - return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; - - if (header_.markv_length_in_bits == 0) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Header markv_length_in_bits field is zero"; - - if (header_.magic_number != kMarkvMagicNumber) - return Diag(SPV_ERROR_INVALID_BINARY) - << "MARK-V binary has incorrect magic number"; - - // TODO(atgoo@github.com): Print version strings. - if (header_.markv_version != GetMarkvVersion()) - return Diag(SPV_ERROR_INVALID_BINARY) - << "MARK-V binary and the codec have different versions"; - - const uint32_t model_type = header_.markv_model >> 16; - const uint32_t model_version = header_.markv_model & 0xFFFF; - if (model_type != model_->model_type()) - return Diag(SPV_ERROR_INVALID_BINARY) - << "MARK-V binary and the codec use different MARK-V models"; - - if (model_version != model_->model_version()) - return Diag(SPV_ERROR_INVALID_BINARY) - << "MARK-V binary and the codec use different versions if the same " - << "MARK-V model"; - - spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. - spirv_.resize(5, 0); - spirv_[0] = kSpirvMagicNumber; - spirv_[1] = header_.spirv_version; - spirv_[2] = header_.spirv_generator; - - if (logger_) { - reader_.SetCallback( - [this](const std::string& str) { logger_->AppendBitSequence(str); }); - } - - while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { - inst_ = {}; - const spv_result_t decode_result = DecodeInstruction(); - if (decode_result != SPV_SUCCESS) return decode_result; - - const spv_result_t validation_result = UpdateValidationState(inst_); - if (validation_result != SPV_SUCCESS) return validation_result; - } - - if (reader_.GetNumReadBits() != header_.markv_length_in_bits || - !reader_.OnlyZeroesLeft()) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "MARK-V binary has wrong stated bit length " - << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; - } - - // Decoding of the module is finished, validation state should have correct - // id bound. - spirv_[3] = GetIdBound(); - - *spirv_binary = std::move(spirv_); - return SPV_SUCCESS; -} - -// TODO(atgoo@github.com): The implementation borrows heavily from -// Parser::parseOperand. -// Consider coupling them together in some way once MARK-V codec is more mature. -// For now it's better to keep the code independent for experimentation -// purposes. -spv_result_t MarkvDecoder::DecodeOperand( - size_t operand_offset, const spv_operand_type_t type, - spv_operand_pattern_t* expected_operands) { - const SpvOp opcode = static_cast(inst_.opcode); - - memset(&operand_, 0, sizeof(operand_)); - - assert((operand_offset >> 16) == 0); - operand_.offset = static_cast(operand_offset); - operand_.type = type; - - // Set default values, may be updated later. - operand_.number_kind = SPV_NUMBER_NONE; - operand_.number_bit_width = 0; - - const size_t first_word_index = inst_words_.size(); - - switch (type) { - case SPV_OPERAND_TYPE_RESULT_ID: { - const spv_result_t result = DecodeResultId(); - if (result != SPV_SUCCESS) return result; - - inst_words_.push_back(inst_.result_id); - SetIdBound(std::max(GetIdBound(), inst_.result_id + 1)); - PromoteIfNeeded(inst_.result_id); - break; - } - - case SPV_OPERAND_TYPE_TYPE_ID: { - const spv_result_t result = DecodeTypeId(); - if (result != SPV_SUCCESS) return result; - - inst_words_.push_back(inst_.type_id); - SetIdBound(std::max(GetIdBound(), inst_.type_id + 1)); - PromoteIfNeeded(inst_.type_id); - break; - } - - case SPV_OPERAND_TYPE_ID: - case SPV_OPERAND_TYPE_OPTIONAL_ID: - case SPV_OPERAND_TYPE_SCOPE_ID: - case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { - uint32_t id = 0; - const spv_result_t result = DecodeRefId(&id); - if (result != SPV_SUCCESS) return result; - - if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; - - if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { - operand_.type = SPV_OPERAND_TYPE_ID; - - if (opcode == SpvOpExtInst && operand_.offset == 3) { - // The current word is the extended instruction set id. - // Set the extended instruction set type for the current - // instruction. - auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); - if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { - return Diag(SPV_ERROR_INVALID_ID) - << "OpExtInst set id " << id - << " does not reference an OpExtInstImport result Id"; - } - inst_.ext_inst_type = ext_inst_type_iter->second; - } - } - - inst_words_.push_back(id); - SetIdBound(std::max(GetIdBound(), id + 1)); - PromoteIfNeeded(id); - break; - } - - case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { - uint32_t word = 0; - const spv_result_t result = DecodeNonIdWord(&word); - if (result != SPV_SUCCESS) return result; - - inst_words_.push_back(word); - - assert(SpvOpExtInst == opcode); - assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE); - spv_ext_inst_desc ext_inst; - if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst)) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid extended instruction number: " << word; - spvPushOperandTypes(ext_inst->operandTypes, expected_operands); - break; - } - - case SPV_OPERAND_TYPE_LITERAL_INTEGER: - case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { - // These are regular single-word literal integer operands. - // Post-parsing validation should check the range of the parsed value. - operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; - // It turns out they are always unsigned integers! - operand_.number_kind = SPV_NUMBER_UNSIGNED_INT; - operand_.number_bit_width = 32; - - uint32_t word = 0; - const spv_result_t result = DecodeNonIdWord(&word); - if (result != SPV_SUCCESS) return result; - - inst_words_.push_back(word); - break; - } - - case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: - case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: { - operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; - if (opcode == SpvOpSwitch) { - // The literal operands have the same type as the value - // referenced by the selector Id. - const uint32_t selector_id = inst_words_.at(1); - const auto type_id_iter = id_to_type_id_.find(selector_id); - if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid OpSwitch: selector id " << selector_id - << " has no type"; - } - uint32_t type_id = type_id_iter->second; - - if (selector_id == type_id) { - // Recall that by convention, a result ID that is a type definition - // maps to itself. - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid OpSwitch: selector id " << selector_id - << " is a type, not a value"; - } - if (auto error = SetNumericTypeInfoForType(&operand_, type_id)) - return error; - if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT && - operand_.number_kind != SPV_NUMBER_SIGNED_INT) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid OpSwitch: selector id " << selector_id - << " is not a scalar integer"; - } - } else { - assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); - // The literal number type is determined by the type Id for the - // constant. - assert(inst_.type_id); - if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id)) - return error; - } - - if (auto error = DecodeLiteralNumber(operand_)) return error; - - break; - } - - case SPV_OPERAND_TYPE_LITERAL_STRING: - case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { - operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING; - std::vector str; - auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode); - - if (codec) { - std::string decoded_string; - const bool huffman_result = - codec->DecodeFromStream(GetReadBitCallback(), &decoded_string); - assert(huffman_result); - if (!huffman_result) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read literal string"; - - if (decoded_string != "kMarkvNoneOfTheAbove") { - std::copy(decoded_string.begin(), decoded_string.end(), - std::back_inserter(str)); - str.push_back('\0'); - } - } - - // The loop is expected to terminate once we encounter '\0' or exhaust - // the bit stream. - if (str.empty()) { - while (true) { - char ch = 0; - if (!reader_.ReadUnencoded(&ch)) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read literal string"; - - str.push_back(ch); - - if (ch == '\0') break; - } - } - - while (str.size() % 4 != 0) str.push_back('\0'); - - inst_words_.resize(inst_words_.size() + str.size() / 4); - std::memcpy(&inst_words_[first_word_index], str.data(), str.size()); - - if (SpvOpExtInstImport == opcode) { - // Record the extended instruction type for the ID for this import. - // There is only one string literal argument to OpExtInstImport, - // so it's sufficient to guard this just on the opcode. - const spv_ext_inst_type_t ext_inst_type = - spvExtInstImportTypeGet(str.data()); - if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid extended instruction import '" << str.data() - << "'"; - } - // We must have parsed a valid result ID. It's a condition - // of the grammar, and we only accept non-zero result Ids. - assert(inst_.result_id); - const bool inserted = - import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type) - .second; - (void)inserted; - assert(inserted); - } - break; - } - - case SPV_OPERAND_TYPE_CAPABILITY: - case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: - case SPV_OPERAND_TYPE_EXECUTION_MODEL: - case SPV_OPERAND_TYPE_ADDRESSING_MODEL: - case SPV_OPERAND_TYPE_MEMORY_MODEL: - case SPV_OPERAND_TYPE_EXECUTION_MODE: - case SPV_OPERAND_TYPE_STORAGE_CLASS: - case SPV_OPERAND_TYPE_DIMENSIONALITY: - case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: - case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: - case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: - case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: - case SPV_OPERAND_TYPE_LINKAGE_TYPE: - case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: - case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: - case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: - case SPV_OPERAND_TYPE_DECORATION: - case SPV_OPERAND_TYPE_BUILT_IN: - case SPV_OPERAND_TYPE_GROUP_OPERATION: - case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: - case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { - // A single word that is a plain enum value. - uint32_t word = 0; - const spv_result_t result = DecodeNonIdWord(&word); - if (result != SPV_SUCCESS) return result; - - inst_words_.push_back(word); - - // Map an optional operand type to its corresponding concrete type. - if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) - operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; - - spv_operand_desc entry; - if (grammar_.lookupOperand(type, word, &entry)) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid " << spvOperandTypeStr(operand_.type) - << " operand: " << word; - } - - // Prepare to accept operands to this operand, if needed. - spvPushOperandTypes(entry->operandTypes, expected_operands); - break; - } - - case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: - case SPV_OPERAND_TYPE_FUNCTION_CONTROL: - case SPV_OPERAND_TYPE_LOOP_CONTROL: - case SPV_OPERAND_TYPE_IMAGE: - case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: - case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: - case SPV_OPERAND_TYPE_SELECTION_CONTROL: { - // This operand is a mask. - uint32_t word = 0; - const spv_result_t result = DecodeNonIdWord(&word); - if (result != SPV_SUCCESS) return result; - - inst_words_.push_back(word); - - // Map an optional operand type to its corresponding concrete type. - if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) - operand_.type = SPV_OPERAND_TYPE_IMAGE; - else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) - operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; - - // Check validity of set mask bits. Also prepare for operands for those - // masks if they have any. To get operand order correct, scan from - // MSB to LSB since we can only prepend operands to a pattern. - // The only case in the grammar where you have more than one mask bit - // having an operand is for image operands. See SPIR-V 3.14 Image - // Operands. - uint32_t remaining_word = word; - for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { - if (remaining_word & mask) { - spv_operand_desc entry; - if (grammar_.lookupOperand(type, mask, &entry)) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Invalid " << spvOperandTypeStr(operand_.type) - << " operand: " << word << " has invalid mask component " - << mask; - } - remaining_word ^= mask; - spvPushOperandTypes(entry->operandTypes, expected_operands); - } - } - if (word == 0) { - // An all-zeroes mask *might* also be valid. - spv_operand_desc entry; - if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { - // Prepare for its operands, if any. - spvPushOperandTypes(entry->operandTypes, expected_operands); - } - } - break; - } - default: - return Diag(SPV_ERROR_INVALID_BINARY) - << "Internal error: Unhandled operand type: " << type; - } - - operand_.num_words = uint16_t(inst_words_.size() - first_word_index); - - assert(spvOperandIsConcrete(operand_.type)); - - parsed_operands_.push_back(operand_); - - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::DecodeInstruction() { - parsed_operands_.clear(); - inst_words_.clear(); - - // Opcode/num_words placeholder, the word will be filled in later. - inst_words_.push_back(0); - - bool num_operands_still_unknown = true; - { - uint32_t opcode = 0; - uint32_t num_operands = 0; - - const spv_result_t opcode_decoding_result = - DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands); - if (opcode_decoding_result < 0) return opcode_decoding_result; - - if (opcode_decoding_result == SPV_SUCCESS) { - inst_.num_operands = static_cast(num_operands); - num_operands_still_unknown = false; - } else { - if (!reader_.ReadVariableWidthU32(&opcode, - model_->opcode_chunk_length())) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read opcode of instruction"; - } - } - - inst_.opcode = static_cast(opcode); - } - - const SpvOp opcode = static_cast(inst_.opcode); - - spv_opcode_desc opcode_desc; - if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { - return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; - } - - spv_operand_pattern_t expected_operands; - expected_operands.reserve(opcode_desc->numTypes); - for (auto i = 0; i < opcode_desc->numTypes; i++) { - expected_operands.push_back( - opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); - } - - if (num_operands_still_unknown) { - if (!OpcodeHasFixedNumberOfOperands(opcode)) { - if (!reader_.ReadVariableWidthU16(&inst_.num_operands, - model_->num_operands_chunk_length())) - return Diag(SPV_ERROR_INVALID_BINARY) - << "Failed to read num_operands of instruction"; - } else { - inst_.num_operands = static_cast(expected_operands.size()); - } - } - - for (operand_index_ = 0; - operand_index_ < static_cast(inst_.num_operands); - ++operand_index_) { - assert(!expected_operands.empty()); - const spv_operand_type_t type = - spvTakeFirstMatchableOperand(&expected_operands); - - const size_t operand_offset = inst_words_.size(); - - const spv_result_t decode_result = - DecodeOperand(operand_offset, type, &expected_operands); - - if (decode_result != SPV_SUCCESS) return decode_result; - } - - assert(inst_.num_operands == parsed_operands_.size()); - - // Only valid while inst_words_ and parsed_operands_ remain unchanged (until - // next DecodeInstruction call). - inst_.words = inst_words_.data(); - inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); - inst_.num_words = static_cast(inst_words_.size()); - inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode)); - - std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_)); - - assert(inst_.num_words == - std::accumulate( - parsed_operands_.begin(), parsed_operands_.end(), 1, - [](int num_words, const spv_parsed_operand_t& operand) { - return num_words += operand.num_words; - }) && - "num_words in instruction doesn't correspond to the sum of num_words" - "in the operands"); - - RecordNumberType(); - ProcessCurInstruction(); - - if (!ReadToByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte)) - return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; - - if (logger_) { - logger_->NewLine(); - std::stringstream ss; - ss << spvOpcodeString(opcode) << " "; - for (size_t index = 1; index < inst_words_.size(); ++index) - ss << inst_words_[index] << " "; - logger_->AppendText(ss.str()); - logger_->NewLine(); - logger_->NewLine(); - if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; - } - - return SPV_SUCCESS; -} - -spv_result_t MarkvDecoder::SetNumericTypeInfoForType( - spv_parsed_operand_t* parsed_operand, uint32_t type_id) { - assert(type_id != 0); - auto type_info_iter = type_id_to_number_type_info_.find(type_id); - if (type_info_iter == type_id_to_number_type_info_.end()) { - return Diag(SPV_ERROR_INVALID_BINARY) - << "Type Id " << type_id << " is not a type"; - } - - const NumberType& info = type_info_iter->second; - if (info.type == SPV_NUMBER_NONE) { - // This is a valid type, but for something other than a scalar number. - return Diag(SPV_ERROR_INVALID_BINARY) - << "Type Id " << type_id << " is not a scalar numeric type"; - } - - parsed_operand->number_kind = info.type; - parsed_operand->number_bit_width = info.bit_width; - // Round up the word count. - parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); - return SPV_SUCCESS; -} - -void MarkvDecoder::RecordNumberType() { - const SpvOp opcode = static_cast(inst_.opcode); - if (spvOpcodeGeneratesType(opcode)) { - NumberType info = {SPV_NUMBER_NONE, 0}; - if (SpvOpTypeInt == opcode) { - info.bit_width = inst_.words[inst_.operands[1].offset]; - info.type = inst_.words[inst_.operands[2].offset] - ? SPV_NUMBER_SIGNED_INT - : SPV_NUMBER_UNSIGNED_INT; - } else if (SpvOpTypeFloat == opcode) { - info.bit_width = inst_.words[inst_.operands[1].offset]; - info.type = SPV_NUMBER_FLOATING; - } - // The *result* Id of a type generating instruction is the type Id. - type_id_to_number_type_info_[inst_.result_id] = info; - } -} - -spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian, - uint32_t magic, uint32_t version, uint32_t generator, - uint32_t id_bound, uint32_t schema) { - MarkvEncoder* encoder = reinterpret_cast(user_data); - return encoder->EncodeHeader(endian, magic, version, generator, id_bound, - schema); -} - -spv_result_t EncodeInstruction(void* user_data, - const spv_parsed_instruction_t* inst) { - MarkvEncoder* encoder = reinterpret_cast(user_data); - return encoder->EncodeInstruction(*inst); -} - -} // namespace - -spv_result_t SpirvToMarkv( - spv_const_context context, const std::vector& spirv, - const MarkvCodecOptions& options, const MarkvModel& markv_model, - MessageConsumer message_consumer, MarkvLogConsumer log_consumer, - MarkvDebugConsumer debug_consumer, std::vector* markv) { - spv_context_t hijack_context = *context; - libspirv::SetContextMessageConsumer(&hijack_context, message_consumer); - - spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()}; - - spv_endianness_t endian; - spv_position_t position = {}; - if (spvBinaryEndianness(&spirv_binary, &endian)) { - return DiagnosticStream(position, hijack_context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V magic number."; - } - - spv_header_t header; - if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) { - return DiagnosticStream(position, hijack_context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V header."; - } - - MarkvEncoder encoder(&hijack_context, options, &markv_model); - - if (log_consumer || debug_consumer) { - encoder.CreateLogger(log_consumer, debug_consumer); - - spv_text text = nullptr; - if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(), - SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, - nullptr) != SPV_SUCCESS) { - return DiagnosticStream(position, hijack_context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Failed to disassemble SPIR-V binary."; - } - assert(text); - encoder.SetDisassembly(std::string(text->str, text->length)); - spvTextDestroy(text); - } - - if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(), - EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) { - return DiagnosticStream(position, hijack_context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Unable to encode to MARK-V."; - } - - *markv = encoder.GetMarkvBinary(); - return SPV_SUCCESS; -} - -spv_result_t MarkvToSpirv( - spv_const_context context, const std::vector& markv, - const MarkvCodecOptions& options, const MarkvModel& markv_model, - MessageConsumer message_consumer, MarkvLogConsumer log_consumer, - MarkvDebugConsumer debug_consumer, std::vector* spirv) { - spv_position_t position = {}; - spv_context_t hijack_context = *context; - libspirv::SetContextMessageConsumer(&hijack_context, message_consumer); - - MarkvDecoder decoder(&hijack_context, markv, options, &markv_model); - - if (log_consumer || debug_consumer) - decoder.CreateLogger(log_consumer, debug_consumer); - - if (decoder.DecodeModule(spirv) != SPV_SUCCESS) { - return DiagnosticStream(position, hijack_context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Unable to decode MARK-V."; - } - - assert(!spirv->empty()); - return SPV_SUCCESS; -} - +} // namespace comp } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/comp/markv_codec.h b/3rdparty/spirv-tools/source/comp/markv_codec.h new file mode 100644 index 000000000..f313d6178 --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv_codec.h @@ -0,0 +1,337 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MARKV_CODEC_H_ +#define SOURCE_COMP_MARKV_CODEC_H_ + +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/comp/huffman_codec.h" +#include "source/comp/markv_model.h" +#include "source/comp/move_to_front.h" +#include "source/diagnostic.h" +#include "source/id_descriptor.h" + +#include "source/val/instruction.h" + +// Base class for MARK-V encoder and decoder. Contains common functionality +// such as: +// - Validator connection and validation state. +// - SPIR-V grammar and helper functions. + +namespace spvtools { +namespace comp { + +class MarkvLogger; + +// Handles for move-to-front sequences. Enums which end with "Begin" define +// handle spaces which start at that value and span 16 or 32 bit wide. +enum : uint64_t { + kMtfNone = 0, + // All ids. + kMtfAll, + // All forward declared ids. + kMtfForwardDeclared, + // All type ids except for generated by OpTypeFunction. + kMtfTypeNonFunction, + // All labels. + kMtfLabel, + // All ids created by instructions which had type_id. + kMtfObject, + // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool. + kMtfTypeScalar, + // All composite types. + kMtfTypeComposite, + // Boolean type or any vector type of it. + kMtfTypeBoolScalarOrVector, + // All float types or any vector floats type. + kMtfTypeFloatScalarOrVector, + // All int types or any vector int type. + kMtfTypeIntScalarOrVector, + // All types declared as return types in OpTypeFunction. + kMtfTypeReturnedByFunction, + // All composite objects. + kMtfComposite, + // All bool objects or vectors of bools. + kMtfBoolScalarOrVector, + // All float objects or vectors of float. + kMtfFloatScalarOrVector, + // All int objects or vectors of int. + kMtfIntScalarOrVector, + // All pointer types which point to composited. + kMtfTypePointerToComposite, + // Used by EncodeMtfRankHuffman. + kMtfGenericNonZeroRank, + // Handle space for ids of specific type. + kMtfIdOfTypeBegin = 0x10000, + // Handle space for ids generated by specific opcode. + kMtfIdGeneratedByOpcode = 0x20000, + // Handle space for ids of objects with type generated by specific opcode. + kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000, + // All vectors of specific component type. + kMtfVectorOfComponentTypeBegin = 0x40000, + // All vector types of specific size. + kMtfTypeVectorOfSizeBegin = 0x50000, + // All pointer types to specific type. + kMtfPointerToTypeBegin = 0x60000, + // All function types which return specific type. + kMtfFunctionTypeWithReturnTypeBegin = 0x70000, + // All function objects which return specific type. + kMtfFunctionWithReturnTypeBegin = 0x80000, + // Short id descriptor space (max 16-bit). + kMtfShortIdDescriptorSpaceBegin = 0x90000, + // Long id descriptor space (32-bit). + kMtfLongIdDescriptorSpaceBegin = 0x100000000, +}; + +class MarkvCodec { + public: + static const uint32_t kMarkvMagicNumber; + + // Mtf ranks smaller than this are encoded with Huffman coding. + static const uint32_t kMtfSmallestRankEncodedByValue; + + // Signals that the mtf rank is too large to be encoded with Huffman. + static const uint32_t kMtfRankEncodedByValueSignal; + + static const uint32_t kShortDescriptorNumBits; + + static const size_t kByteBreakAfterInstIfLessThanUntilNextByte; + + static uint32_t GetMarkvVersion(); + + virtual ~MarkvCodec(); + + protected: + struct MarkvHeader { + MarkvHeader(); + + uint32_t magic_number; + uint32_t markv_version; + // Magic number to identify or verify MarkvModel used for encoding. + uint32_t markv_model = 0; + uint32_t markv_length_in_bits = 0; + uint32_t spirv_version = 0; + uint32_t spirv_generator = 0; + }; + + // |model| is owned by the caller, must be not null and valid during the + // lifetime of the codec. + MarkvCodec(spv_const_context context, spv_validator_options validator_options, + const MarkvModel* model); + + // Returns instruction which created |id| or nullptr if such instruction was + // not registered. + const val::Instruction* FindDef(uint32_t id) const { + const auto it = id_to_def_instruction_.find(id); + if (it == id_to_def_instruction_.end()) return nullptr; + return it->second; + } + + size_t GetNumBitsToNextByte(size_t bit_pos) const; + bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) const; + + // Returns type id of vector type component. + uint32_t GetVectorComponentType(uint32_t vector_type_id) const { + const val::Instruction* type_inst = FindDef(vector_type_id); + assert(type_inst); + assert(type_inst->opcode() == SpvOpTypeVector); + + const uint32_t component_type = + type_inst->word(type_inst->operands()[1].offset); + return component_type; + } + + // Returns mtf handle for ids of given type. + uint64_t GetMtfIdOfType(uint32_t type_id) const { + return kMtfIdOfTypeBegin + type_id; + } + + // Returns mtf handle for ids generated by given opcode. + uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const { + return kMtfIdGeneratedByOpcode + opcode; + } + + // Returns mtf handle for ids of type generated by given opcode. + uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const { + return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode; + } + + // Returns mtf handle for vectors of specific component type. + uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const { + return kMtfVectorOfComponentTypeBegin + type_id; + } + + // Returns mtf handle for vector type of specific size. + uint64_t GetMtfTypeVectorOfSize(uint32_t size) const { + return kMtfTypeVectorOfSizeBegin + size; + } + + // Returns mtf handle for pointers to specific size. + uint64_t GetMtfPointerToType(uint32_t type_id) const { + return kMtfPointerToTypeBegin + type_id; + } + + // Returns mtf handle for function types with given return type. + uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const { + return kMtfFunctionTypeWithReturnTypeBegin + type_id; + } + + // Returns mtf handle for functions with given return type. + uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const { + return kMtfFunctionWithReturnTypeBegin + type_id; + } + + // Returns mtf handle for the given long id descriptor. + uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const { + return kMtfLongIdDescriptorSpaceBegin + descriptor; + } + + // Returns mtf handle for the given short id descriptor. + uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const { + return kMtfShortIdDescriptorSpaceBegin + descriptor; + } + + // Process data from the current instruction. This would update MTFs and + // other data containers. + void ProcessCurInstruction(); + + // Returns move-to-front handle to be used for the current operand slot. + // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar. + uint64_t GetRuleBasedMtf(); + + // Returns words of the current instruction. Decoder has a different + // implementation and the array is valid only until the previously decoded + // word. + virtual const uint32_t* GetInstWords() const { return inst_.words; } + + // Returns the opcode of the previous instruction. + SpvOp GetPrevOpcode() const { + if (instructions_.empty()) return SpvOpNop; + + return instructions_.back()->opcode(); + } + + // Returns diagnostic stream, position index is set to instruction number. + DiagnosticStream Diag(spv_result_t error_code) const { + return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer, + "", error_code); + } + + // Returns current id bound. + uint32_t GetIdBound() const { return id_bound_; } + + // Sets current id bound, expected to be no lower than the previous one. + void SetIdBound(uint32_t id_bound) { + assert(id_bound >= id_bound_); + id_bound_ = id_bound; + } + + // Returns Huffman codec for ranks of the mtf with given |handle|. + // Different mtfs can use different rank distributions. + // May return nullptr if the codec doesn't exist. + const HuffmanCodec* GetMtfHuffmanCodec(uint64_t handle) const { + const auto it = mtf_huffman_codecs_.find(handle); + if (it == mtf_huffman_codecs_.end()) return nullptr; + return it->second.get(); + } + + // Promotes id in all move-to-front sequences if ids can be shared by multiple + // sequences. + void PromoteIfNeeded(uint32_t id) { + if (!model_->AnyDescriptorHasCodingScheme() && + model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kShortDescriptor) { + // Move-to-front sequences do not share ids. Nothing to do. + return; + } + multi_mtf_.Promote(id); + } + + spv_validator_options validator_options_ = nullptr; + const AssemblyGrammar grammar_; + MarkvHeader header_; + + // MARK-V model, not owned. + const MarkvModel* model_ = nullptr; + + // Current instruction, current operand and current operand index. + spv_parsed_instruction_t inst_; + spv_parsed_operand_t operand_; + uint32_t operand_index_; + + // Maps a result ID to its type ID. By convention: + // - a result ID that is a type definition maps to itself. + // - a result ID without a type maps to 0. (E.g. for OpLabel) + std::unordered_map id_to_type_id_; + + // Container for all move-to-front sequences. + MultiMoveToFront multi_mtf_; + + // Id of the current function or zero if outside of function. + uint32_t cur_function_id_ = 0; + + // Return type of the current function. + uint32_t cur_function_return_type_ = 0; + + // Remaining function parameter types. This container is filled on OpFunction, + // and drained on OpFunctionParameter. + std::list remaining_function_parameter_types_; + + // List of ids local to the current function. + std::vector ids_local_to_cur_function_; + + // List of instructions in the order they are given in the module. + std::vector> instructions_; + + // Container/computer for long (32-bit) id descriptors. + IdDescriptorCollection long_id_descriptors_; + + // Container/computer for short id descriptors. + // Short descriptors are stored in uint32_t, but their actual bit width is + // defined with kShortDescriptorNumBits. + // It doesn't seem logical to have a different computer for short id + // descriptors, since one could actually map/truncate long descriptors. + // But as short descriptors have collisions, the efficiency of + // compression depends on the collision pattern, and short descriptors + // produced by function ShortHashU32Array have been empirically proven to + // produce better results. + IdDescriptorCollection short_id_descriptors_; + + // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't + // need to contain a different codec for every handle as most use one and the + // same. + std::map>> + mtf_huffman_codecs_; + + // If not nullptr, codec will log comments on the compression process. + std::unique_ptr logger_; + + spv_const_context context_ = nullptr; + + private: + // Maps result id to the instruction which defined it. + std::unordered_map id_to_def_instruction_; + + uint32_t id_bound_ = 1; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_CODEC_H_ diff --git a/3rdparty/spirv-tools/source/comp/markv_decoder.cpp b/3rdparty/spirv-tools/source/comp/markv_decoder.cpp new file mode 100644 index 000000000..22115831d --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv_decoder.cpp @@ -0,0 +1,925 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/markv_decoder.h" + +#include +#include +#include + +#include "source/ext_inst.h" +#include "source/opcode.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace comp { + +spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) { + auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); + + if (codec) { + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to decode non-id word with Huffman"; + + if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { + // The word decoded successfully. + *word = uint32_t(decoded_value); + assert(*word == decoded_value); + return SPV_SUCCESS; + } + + // Received kMarkvNoneOfTheAbove signal, use fallback decoding. + } + + const size_t chunk_length = + model_->GetOperandVariableWidthChunkLength(operand_.type); + if (chunk_length) { + if (!reader_.ReadVariableWidthU32(word, chunk_length)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to decode non-id word with varint"; + } else { + if (!reader_.ReadUnencoded(word)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read unencoded non-id word"; + } + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands( + uint32_t* opcode, uint32_t* num_operands) { + // First try to use the Markov chain codec. + auto* codec = + model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); + if (codec) { + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode opcode_and_num_operands, previous opcode is " + << spvOpcodeString(GetPrevOpcode()); + + if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { + // The word was successfully decoded. + *opcode = uint32_t(decoded_value & 0xFFFF); + *num_operands = uint32_t(decoded_value >> 16); + return SPV_SUCCESS; + } + + // Received kMarkvNoneOfTheAbove signal, use fallback decoding. + } + + // Fallback to base-rate codec. + codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); + assert(codec); + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode opcode_and_num_operands with global codec"; + + if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) { + // Received kMarkvNoneOfTheAbove signal, fallback further. + return SPV_UNSUPPORTED; + } + + *opcode = uint32_t(decoded_value & 0xFFFF); + *num_operands = uint32_t(decoded_value >> 16); + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf, + uint32_t fallback_method, + uint32_t* rank) { + const auto* codec = GetMtfHuffmanCodec(mtf); + if (!codec) { + assert(fallback_method != kMtfNone); + codec = GetMtfHuffmanCodec(fallback_method); + } + + if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank"; + + uint32_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman"; + + if (decoded_value == kMtfRankEncodedByValueSignal) { + // Decode by value. + if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode MTF rank with varint"; + *rank += MarkvCodec::kMtfSmallestRankEncodedByValue; + } else { + // Decode using Huffman coding. + assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue); + *rank = decoded_value; + } + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) { + auto* codec = + model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); + + uint64_t mtf = kMtfNone; + if (codec) { + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode descriptor with Huffman"; + + if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { + const uint32_t long_descriptor = uint32_t(decoded_value); + mtf = GetMtfLongIdDescriptor(long_descriptor); + } + } + + if (mtf == kMtfNone) { + if (model_->id_fallback_strategy() != + MarkvModel::IdFallbackStrategy::kShortDescriptor) { + return SPV_UNSUPPORTED; + } + + uint64_t decoded_value = 0; + if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits)) + return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor"; + const uint32_t short_descriptor = uint32_t(decoded_value); + if (short_descriptor == 0) { + // Forward declared id. + return SPV_UNSUPPORTED; + } + mtf = GetMtfShortIdDescriptor(short_descriptor); + } + + return DecodeExistingId(mtf, id); +} + +spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) { + assert(multi_mtf_.GetSize(mtf) > 0); + *id = 0; + + uint32_t rank = 0; + + if (multi_mtf_.GetSize(mtf) == 1) { + rank = 1; + } else { + const spv_result_t result = + DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank); + if (result != SPV_SUCCESS) return result; + } + + assert(rank); + if (!multi_mtf_.ValueFromRank(mtf, rank, id)) + return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds"; + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) { + { + const spv_result_t result = DecodeIdWithDescriptor(id); + if (result != SPV_UNSUPPORTED) return result; + } + + const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( + SpvOp(inst_.opcode))(operand_index_); + uint32_t rank = 0; + *id = 0; + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + uint64_t mtf = GetRuleBasedMtf(); + if (mtf != kMtfNone && !can_forward_declare) { + return DecodeExistingId(mtf, id); + } + + if (mtf == kMtfNone) mtf = kMtfAll; + { + const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank); + if (result != SPV_SUCCESS) return result; + } + + if (rank == 0) { + // This is the first occurrence of a forward declared id. + *id = GetIdBound(); + SetIdBound(*id + 1); + multi_mtf_.Insert(kMtfAll, *id); + multi_mtf_.Insert(kMtfForwardDeclared, *id); + if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id); + } else { + if (!multi_mtf_.ValueFromRank(mtf, rank, id)) + return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; + } + } else { + assert(can_forward_declare); + + if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode MTF rank with varint"; + + if (rank == 0) { + // This is the first occurrence of a forward declared id. + *id = GetIdBound(); + SetIdBound(*id + 1); + multi_mtf_.Insert(kMtfForwardDeclared, *id); + } else { + if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id)) + return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; + } + } + assert(*id); + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeTypeId() { + if (inst_.opcode == SpvOpFunctionParameter) { + assert(!remaining_function_parameter_types_.empty()); + inst_.type_id = remaining_function_parameter_types_.front(); + remaining_function_parameter_types_.pop_front(); + return SPV_SUCCESS; + } + + { + const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id); + if (result != SPV_UNSUPPORTED) return result; + } + + assert(model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased); + + uint64_t mtf = GetRuleBasedMtf(); + assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( + operand_index_)); + + if (mtf == kMtfNone) { + mtf = kMtfTypeNonFunction; + // Function types should have been handled by GetRuleBasedMtf. + assert(inst_.opcode != SpvOpFunction); + } + + return DecodeExistingId(mtf, &inst_.type_id); +} + +spv_result_t MarkvDecoder::DecodeResultId() { + uint32_t rank = 0; + + const uint64_t num_still_forward_declared = + multi_mtf_.GetSize(kMtfForwardDeclared); + + if (num_still_forward_declared) { + // Some ids were forward declared. Check if this id is one of them. + uint64_t id_was_forward_declared; + if (!reader_.ReadBits(&id_was_forward_declared, 1)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read id_was_forward_declared flag"; + + if (id_was_forward_declared) { + if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read MTF rank of forward declared id"; + + if (rank) { + // The id was forward declared, recover it from kMtfForwardDeclared. + if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, + &inst_.result_id)) + return Diag(SPV_ERROR_INTERNAL) + << "Forward declared MTF rank is out of bounds"; + + // We can now remove the id from kMtfForwardDeclared. + if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to remove id from kMtfForwardDeclared"; + } + } + } + + if (inst_.result_id == 0) { + // The id was not forward declared, issue a new id. + inst_.result_id = GetIdBound(); + SetIdBound(inst_.result_id + 1); + } + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + if (!rank) { + multi_mtf_.Insert(kMtfAll, inst_.result_id); + } + } + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeLiteralNumber( + const spv_parsed_operand_t& operand) { + if (operand.number_bit_width <= 32) { + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + inst_words_.push_back(word); + } else { + assert(operand.number_bit_width <= 64); + uint64_t word = 0; + if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { + if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; + } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { + int64_t val = 0; + if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), + model_->s64_block_exponent())) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64"; + std::memcpy(&word, &val, 8); + } else if (operand.number_kind == SPV_NUMBER_FLOATING) { + if (!reader_.ReadUnencoded(&word)) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; + } else { + return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; + } + inst_words_.push_back(static_cast(word)); + inst_words_.push_back(static_cast(word >> 32)); + } + return SPV_SUCCESS; +} + +bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) { + const size_t num_bits_to_next_byte = + GetNumBitsToNextByte(reader_.GetNumReadBits()); + if (num_bits_to_next_byte == 0 || + num_bits_to_next_byte > byte_break_if_less_than) + return true; + + uint64_t bits = 0; + if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false; + + assert(bits == 0); + if (bits != 0) return false; + + return true; +} + +spv_result_t MarkvDecoder::DecodeModule(std::vector* spirv_binary) { + const bool header_read_success = + reader_.ReadUnencoded(&header_.magic_number) && + reader_.ReadUnencoded(&header_.markv_version) && + reader_.ReadUnencoded(&header_.markv_model) && + reader_.ReadUnencoded(&header_.markv_length_in_bits) && + reader_.ReadUnencoded(&header_.spirv_version) && + reader_.ReadUnencoded(&header_.spirv_generator); + + if (!header_read_success) + return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; + + if (header_.markv_length_in_bits == 0) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Header markv_length_in_bits field is zero"; + + if (header_.magic_number != MarkvCodec::kMarkvMagicNumber) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary has incorrect magic number"; + + // TODO(atgoo@github.com): Print version strings. + if (header_.markv_version != MarkvCodec::GetMarkvVersion()) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary and the codec have different versions"; + + const uint32_t model_type = header_.markv_model >> 16; + const uint32_t model_version = header_.markv_model & 0xFFFF; + if (model_type != model_->model_type()) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary and the codec use different MARK-V models"; + + if (model_version != model_->model_version()) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary and the codec use different versions if the same " + << "MARK-V model"; + + spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. + spirv_.resize(5, 0); + spirv_[0] = SpvMagicNumber; + spirv_[1] = header_.spirv_version; + spirv_[2] = header_.spirv_generator; + + if (logger_) { + reader_.SetCallback( + [this](const std::string& str) { logger_->AppendBitSequence(str); }); + } + + while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { + inst_ = {}; + const spv_result_t decode_result = DecodeInstruction(); + if (decode_result != SPV_SUCCESS) return decode_result; + } + + if (validator_options_) { + spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()}; + const spv_result_t result = spvValidateWithOptions( + context_, validator_options_, &validation_binary, nullptr); + if (result != SPV_SUCCESS) return result; + } + + // Validate the decode binary + if (reader_.GetNumReadBits() != header_.markv_length_in_bits || + !reader_.OnlyZeroesLeft()) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary has wrong stated bit length " + << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; + } + + // Decoding of the module is finished, validation state should have correct + // id bound. + spirv_[3] = GetIdBound(); + + *spirv_binary = std::move(spirv_); + return SPV_SUCCESS; +} + +// TODO(atgoo@github.com): The implementation borrows heavily from +// Parser::parseOperand. +// Consider coupling them together in some way once MARK-V codec is more mature. +// For now it's better to keep the code independent for experimentation +// purposes. +spv_result_t MarkvDecoder::DecodeOperand( + size_t operand_offset, const spv_operand_type_t type, + spv_operand_pattern_t* expected_operands) { + const SpvOp opcode = static_cast(inst_.opcode); + + memset(&operand_, 0, sizeof(operand_)); + + assert((operand_offset >> 16) == 0); + operand_.offset = static_cast(operand_offset); + operand_.type = type; + + // Set default values, may be updated later. + operand_.number_kind = SPV_NUMBER_NONE; + operand_.number_bit_width = 0; + + const size_t first_word_index = inst_words_.size(); + + switch (type) { + case SPV_OPERAND_TYPE_RESULT_ID: { + const spv_result_t result = DecodeResultId(); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(inst_.result_id); + SetIdBound(std::max(GetIdBound(), inst_.result_id + 1)); + PromoteIfNeeded(inst_.result_id); + break; + } + + case SPV_OPERAND_TYPE_TYPE_ID: { + const spv_result_t result = DecodeTypeId(); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(inst_.type_id); + SetIdBound(std::max(GetIdBound(), inst_.type_id + 1)); + PromoteIfNeeded(inst_.type_id); + break; + } + + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { + uint32_t id = 0; + const spv_result_t result = DecodeRefId(&id); + if (result != SPV_SUCCESS) return result; + + if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; + + if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { + operand_.type = SPV_OPERAND_TYPE_ID; + + if (opcode == SpvOpExtInst && operand_.offset == 3) { + // The current word is the extended instruction set id. + // Set the extended instruction set type for the current + // instruction. + auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); + if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { + return Diag(SPV_ERROR_INVALID_ID) + << "OpExtInst set id " << id + << " does not reference an OpExtInstImport result Id"; + } + inst_.ext_inst_type = ext_inst_type_iter->second; + } + } + + inst_words_.push_back(id); + SetIdBound(std::max(GetIdBound(), id + 1)); + PromoteIfNeeded(id); + break; + } + + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + + assert(SpvOpExtInst == opcode); + assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE); + spv_ext_inst_desc ext_inst; + if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid extended instruction number: " << word; + spvPushOperandTypes(ext_inst->operandTypes, expected_operands); + break; + } + + case SPV_OPERAND_TYPE_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { + // These are regular single-word literal integer operands. + // Post-parsing validation should check the range of the parsed value. + operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; + // It turns out they are always unsigned integers! + operand_.number_kind = SPV_NUMBER_UNSIGNED_INT; + operand_.number_bit_width = 32; + + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + break; + } + + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: + case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: { + operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; + if (opcode == SpvOpSwitch) { + // The literal operands have the same type as the value + // referenced by the selector Id. + const uint32_t selector_id = inst_words_.at(1); + const auto type_id_iter = id_to_type_id_.find(selector_id); + if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid OpSwitch: selector id " << selector_id + << " has no type"; + } + uint32_t type_id = type_id_iter->second; + + if (selector_id == type_id) { + // Recall that by convention, a result ID that is a type definition + // maps to itself. + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid OpSwitch: selector id " << selector_id + << " is a type, not a value"; + } + if (auto error = SetNumericTypeInfoForType(&operand_, type_id)) + return error; + if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT && + operand_.number_kind != SPV_NUMBER_SIGNED_INT) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid OpSwitch: selector id " << selector_id + << " is not a scalar integer"; + } + } else { + assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); + // The literal number type is determined by the type Id for the + // constant. + assert(inst_.type_id); + if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id)) + return error; + } + + if (auto error = DecodeLiteralNumber(operand_)) return error; + + break; + } + + case SPV_OPERAND_TYPE_LITERAL_STRING: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { + operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING; + std::vector str; + auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode); + + if (codec) { + std::string decoded_string; + const bool huffman_result = + codec->DecodeFromStream(GetReadBitCallback(), &decoded_string); + assert(huffman_result); + if (!huffman_result) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read literal string"; + + if (decoded_string != "kMarkvNoneOfTheAbove") { + std::copy(decoded_string.begin(), decoded_string.end(), + std::back_inserter(str)); + str.push_back('\0'); + } + } + + // The loop is expected to terminate once we encounter '\0' or exhaust + // the bit stream. + if (str.empty()) { + while (true) { + char ch = 0; + if (!reader_.ReadUnencoded(&ch)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read literal string"; + + str.push_back(ch); + + if (ch == '\0') break; + } + } + + while (str.size() % 4 != 0) str.push_back('\0'); + + inst_words_.resize(inst_words_.size() + str.size() / 4); + std::memcpy(&inst_words_[first_word_index], str.data(), str.size()); + + if (SpvOpExtInstImport == opcode) { + // Record the extended instruction type for the ID for this import. + // There is only one string literal argument to OpExtInstImport, + // so it's sufficient to guard this just on the opcode. + const spv_ext_inst_type_t ext_inst_type = + spvExtInstImportTypeGet(str.data()); + if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid extended instruction import '" << str.data() + << "'"; + } + // We must have parsed a valid result ID. It's a condition + // of the grammar, and we only accept non-zero result Ids. + assert(inst_.result_id); + const bool inserted = + import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type) + .second; + (void)inserted; + assert(inserted); + } + break; + } + + case SPV_OPERAND_TYPE_CAPABILITY: + case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: + case SPV_OPERAND_TYPE_EXECUTION_MODEL: + case SPV_OPERAND_TYPE_ADDRESSING_MODEL: + case SPV_OPERAND_TYPE_MEMORY_MODEL: + case SPV_OPERAND_TYPE_EXECUTION_MODE: + case SPV_OPERAND_TYPE_STORAGE_CLASS: + case SPV_OPERAND_TYPE_DIMENSIONALITY: + case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: + case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: + case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: + case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: + case SPV_OPERAND_TYPE_LINKAGE_TYPE: + case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: + case SPV_OPERAND_TYPE_DECORATION: + case SPV_OPERAND_TYPE_BUILT_IN: + case SPV_OPERAND_TYPE_GROUP_OPERATION: + case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: + case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { + // A single word that is a plain enum value. + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + + // Map an optional operand type to its corresponding concrete type. + if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) + operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; + + spv_operand_desc entry; + if (grammar_.lookupOperand(type, word, &entry)) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid " << spvOperandTypeStr(operand_.type) + << " operand: " << word; + } + + // Prepare to accept operands to this operand, if needed. + spvPushOperandTypes(entry->operandTypes, expected_operands); + break; + } + + case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: + case SPV_OPERAND_TYPE_FUNCTION_CONTROL: + case SPV_OPERAND_TYPE_LOOP_CONTROL: + case SPV_OPERAND_TYPE_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: + case SPV_OPERAND_TYPE_SELECTION_CONTROL: { + // This operand is a mask. + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + + // Map an optional operand type to its corresponding concrete type. + if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) + operand_.type = SPV_OPERAND_TYPE_IMAGE; + else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) + operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; + + // Check validity of set mask bits. Also prepare for operands for those + // masks if they have any. To get operand order correct, scan from + // MSB to LSB since we can only prepend operands to a pattern. + // The only case in the grammar where you have more than one mask bit + // having an operand is for image operands. See SPIR-V 3.14 Image + // Operands. + uint32_t remaining_word = word; + for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { + if (remaining_word & mask) { + spv_operand_desc entry; + if (grammar_.lookupOperand(type, mask, &entry)) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid " << spvOperandTypeStr(operand_.type) + << " operand: " << word << " has invalid mask component " + << mask; + } + remaining_word ^= mask; + spvPushOperandTypes(entry->operandTypes, expected_operands); + } + } + if (word == 0) { + // An all-zeroes mask *might* also be valid. + spv_operand_desc entry; + if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { + // Prepare for its operands, if any. + spvPushOperandTypes(entry->operandTypes, expected_operands); + } + } + break; + } + default: + return Diag(SPV_ERROR_INVALID_BINARY) + << "Internal error: Unhandled operand type: " << type; + } + + operand_.num_words = uint16_t(inst_words_.size() - first_word_index); + + assert(spvOperandIsConcrete(operand_.type)); + + parsed_operands_.push_back(operand_); + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeInstruction() { + parsed_operands_.clear(); + inst_words_.clear(); + + // Opcode/num_words placeholder, the word will be filled in later. + inst_words_.push_back(0); + + bool num_operands_still_unknown = true; + { + uint32_t opcode = 0; + uint32_t num_operands = 0; + + const spv_result_t opcode_decoding_result = + DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands); + if (opcode_decoding_result < 0) return opcode_decoding_result; + + if (opcode_decoding_result == SPV_SUCCESS) { + inst_.num_operands = static_cast(num_operands); + num_operands_still_unknown = false; + } else { + if (!reader_.ReadVariableWidthU32(&opcode, + model_->opcode_chunk_length())) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read opcode of instruction"; + } + } + + inst_.opcode = static_cast(opcode); + } + + const SpvOp opcode = static_cast(inst_.opcode); + + spv_opcode_desc opcode_desc; + if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { + return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; + } + + spv_operand_pattern_t expected_operands; + expected_operands.reserve(opcode_desc->numTypes); + for (auto i = 0; i < opcode_desc->numTypes; i++) { + expected_operands.push_back( + opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); + } + + if (num_operands_still_unknown) { + if (!OpcodeHasFixedNumberOfOperands(opcode)) { + if (!reader_.ReadVariableWidthU16(&inst_.num_operands, + model_->num_operands_chunk_length())) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read num_operands of instruction"; + } else { + inst_.num_operands = static_cast(expected_operands.size()); + } + } + + for (operand_index_ = 0; + operand_index_ < static_cast(inst_.num_operands); + ++operand_index_) { + assert(!expected_operands.empty()); + const spv_operand_type_t type = + spvTakeFirstMatchableOperand(&expected_operands); + + const size_t operand_offset = inst_words_.size(); + + const spv_result_t decode_result = + DecodeOperand(operand_offset, type, &expected_operands); + + if (decode_result != SPV_SUCCESS) return decode_result; + } + + assert(inst_.num_operands == parsed_operands_.size()); + + // Only valid while inst_words_ and parsed_operands_ remain unchanged (until + // next DecodeInstruction call). + inst_.words = inst_words_.data(); + inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); + inst_.num_words = static_cast(inst_words_.size()); + inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode)); + + std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_)); + + assert(inst_.num_words == + std::accumulate( + parsed_operands_.begin(), parsed_operands_.end(), 1, + [](int num_words, const spv_parsed_operand_t& operand) { + return num_words += operand.num_words; + }) && + "num_words in instruction doesn't correspond to the sum of num_words" + "in the operands"); + + RecordNumberType(); + ProcessCurInstruction(); + + if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte)) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; + + if (logger_) { + logger_->NewLine(); + std::stringstream ss; + ss << spvOpcodeString(opcode) << " "; + for (size_t index = 1; index < inst_words_.size(); ++index) + ss << inst_words_[index] << " "; + logger_->AppendText(ss.str()); + logger_->NewLine(); + logger_->NewLine(); + if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; + } + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::SetNumericTypeInfoForType( + spv_parsed_operand_t* parsed_operand, uint32_t type_id) { + assert(type_id != 0); + auto type_info_iter = type_id_to_number_type_info_.find(type_id); + if (type_info_iter == type_id_to_number_type_info_.end()) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Type Id " << type_id << " is not a type"; + } + + const NumberType& info = type_info_iter->second; + if (info.type == SPV_NUMBER_NONE) { + // This is a valid type, but for something other than a scalar number. + return Diag(SPV_ERROR_INVALID_BINARY) + << "Type Id " << type_id << " is not a scalar numeric type"; + } + + parsed_operand->number_kind = info.type; + parsed_operand->number_bit_width = info.bit_width; + // Round up the word count. + parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); + return SPV_SUCCESS; +} + +void MarkvDecoder::RecordNumberType() { + const SpvOp opcode = static_cast(inst_.opcode); + if (spvOpcodeGeneratesType(opcode)) { + NumberType info = {SPV_NUMBER_NONE, 0}; + if (SpvOpTypeInt == opcode) { + info.bit_width = inst_.words[inst_.operands[1].offset]; + info.type = inst_.words[inst_.operands[2].offset] + ? SPV_NUMBER_SIGNED_INT + : SPV_NUMBER_UNSIGNED_INT; + } else if (SpvOpTypeFloat == opcode) { + info.bit_width = inst_.words[inst_.operands[1].offset]; + info.type = SPV_NUMBER_FLOATING; + } + // The *result* Id of a type generating instruction is the type Id. + type_id_to_number_type_info_[inst_.result_id] = info; + } +} + +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/comp/markv_decoder.h b/3rdparty/spirv-tools/source/comp/markv_decoder.h new file mode 100644 index 000000000..4d8402b44 --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv_decoder.h @@ -0,0 +1,175 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/bit_stream.h" +#include "source/comp/markv.h" +#include "source/comp/markv_codec.h" +#include "source/comp/markv_logger.h" +#include "source/util/make_unique.h" + +#ifndef SOURCE_COMP_MARKV_DECODER_H_ +#define SOURCE_COMP_MARKV_DECODER_H_ + +namespace spvtools { +namespace comp { + +class MarkvLogger; + +// Decodes MARK-V buffers written by MarkvEncoder. +class MarkvDecoder : public MarkvCodec { + public: + // |model| is owned by the caller, must be not null and valid during the + // lifetime of MarkvEncoder. + MarkvDecoder(spv_const_context context, const std::vector& markv, + const MarkvCodecOptions& options, const MarkvModel* model) + : MarkvCodec(context, GetValidatorOptions(options), model), + options_(options), + reader_(markv) { + SetIdBound(1); + parsed_operands_.reserve(25); + inst_words_.reserve(25); + } + ~MarkvDecoder() = default; + + // Creates an internal logger which writes comments on the decoding process. + void CreateLogger(MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer) { + logger_ = MakeUnique(log_consumer, debug_consumer); + } + + // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. + // Can be called only once. Fails if data of wrong format or ends prematurely, + // of if validation fails. + spv_result_t DecodeModule(std::vector* spirv_binary); + + // Creates and returns validator options. Returned value owned by the caller. + static spv_validator_options GetValidatorOptions( + const MarkvCodecOptions& options) { + return options.validate_spirv_binary ? spvValidatorOptionsCreate() + : nullptr; + } + + private: + // Describes the format of a typed literal number. + struct NumberType { + spv_number_kind_t type; + uint32_t bit_width; + }; + + // Reads a single bit from reader_. The read bit is stored in |bit|. + // Returns false iff reader_ fails. + bool ReadBit(bool* bit) { + uint64_t bits = 0; + const bool result = reader_.ReadBits(&bits, 1); + if (result) *bit = bits ? true : false; + return result; + }; + + // Returns ReadBit bound to the class object. + std::function GetReadBitCallback() { + return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1); + } + + // Reads a single non-id word from bit stream. operand_.type determines if + // the word needs to be decoded and how. + spv_result_t DecodeNonIdWord(uint32_t* word); + + // Reads and decodes both opcode and num_operands as a single code. + // Returns SPV_UNSUPPORTED iff no suitable codec was found. + spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode, + uint32_t* num_operands); + + // Reads mtf rank from bit stream. |mtf| is used to determine the codec + // scheme. |fallback_method| is used if no codec defined for |mtf|. + spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method, + uint32_t* rank); + + // Reads id using coding based on mtf associated with the id descriptor. + // Returns SPV_UNSUPPORTED iff fallback method needs to be used. + spv_result_t DecodeIdWithDescriptor(uint32_t* id); + + // Reads id using coding based on the given |mtf|, which is expected to + // contain the needed |id|. + spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id); + + // Reads type id of the current instruction if can't be inferred. + spv_result_t DecodeTypeId(); + + // Reads result id of the current instruction if can't be inferred. + spv_result_t DecodeResultId(); + + // Reads id which is neither type nor result id. + spv_result_t DecodeRefId(uint32_t* id); + + // Reads and discards bits until the beginning of the next byte if the + // number of bits until the next byte is less than |byte_break_if_less_than|. + bool ReadToByteBreak(size_t byte_break_if_less_than); + + // Returns instruction words decoded up to this point. + const uint32_t* GetInstWords() const override { return inst_words_.data(); } + + // Reads a literal number as it is described in |operand| from the bit stream, + // decodes and writes it to spirv_. + spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); + + // Reads instruction from bit stream, decodes and validates it. + // Decoded instruction is valid until the next call of DecodeInstruction(). + spv_result_t DecodeInstruction(); + + // Read operand from the stream decodes and validates it. + spv_result_t DecodeOperand(size_t operand_offset, + const spv_operand_type_t type, + spv_operand_pattern_t* expected_operands); + + // Records the numeric type for an operand according to the type information + // associated with the given non-zero type Id. This can fail if the type Id + // is not a type Id, or if the type Id does not reference a scalar numeric + // type. On success, return SPV_SUCCESS and populates the num_words, + // number_kind, and number_bit_width fields of parsed_operand. + spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, + uint32_t type_id); + + // Records the number type for the current instruction, if it generates a + // type. For types that aren't scalar numbers, record something with number + // kind SPV_NUMBER_NONE. + void RecordNumberType(); + + MarkvCodecOptions options_; + + // Temporary sink where decoded SPIR-V words are written. Once it contains the + // entire module, the container is moved and returned. + std::vector spirv_; + + // Bit stream containing encoded data. + BitReaderWord64 reader_; + + // Temporary storage for operands of the currently parsed instruction. + // Valid until next DecodeInstruction call. + std::vector parsed_operands_; + + // Temporary storage for current instruction words. + // Valid until next DecodeInstruction call. + std::vector inst_words_; + + // Maps a type ID to its number type description. + std::unordered_map type_id_to_number_type_info_; + + // Maps an ExtInstImport id to the extended instruction type. + std::unordered_map import_id_to_ext_inst_type_; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_DECODER_H_ diff --git a/3rdparty/spirv-tools/source/comp/markv_encoder.cpp b/3rdparty/spirv-tools/source/comp/markv_encoder.cpp new file mode 100644 index 000000000..1abd58646 --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv_encoder.cpp @@ -0,0 +1,486 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/markv_encoder.h" + +#include "source/binary.h" +#include "source/opcode.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace comp { +namespace { + +const size_t kCommentNumWhitespaces = 2; + +} // namespace + +spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) { + auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); + + if (codec) { + uint64_t bits = 0; + size_t num_bits = 0; + if (codec->Encode(word, &bits, &num_bits)) { + // Encoding successful. + writer_.WriteBits(bits, num_bits); + return SPV_SUCCESS; + } else { + // Encoding failed, write kMarkvNoneOfTheAbove flag. + if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Non-id word Huffman table for " + << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " + << operand_index_ << " is missing kMarkvNoneOfTheAbove"; + writer_.WriteBits(bits, num_bits); + } + } + + // Fallback encoding. + const size_t chunk_length = + model_->GetOperandVariableWidthChunkLength(operand_.type); + if (chunk_length) { + writer_.WriteVariableWidthU32(word, chunk_length); + } else { + writer_.WriteUnencoded(word); + } + return SPV_SUCCESS; +} + +spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode, + uint32_t num_operands) { + uint64_t bits = 0; + size_t num_bits = 0; + + const uint32_t word = opcode | (num_operands << 16); + + // First try to use the Markov chain codec. + auto* codec = + model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); + if (codec) { + if (codec->Encode(word, &bits, &num_bits)) { + // The word was successfully encoded into bits/num_bits. + writer_.WriteBits(bits, num_bits); + return SPV_SUCCESS; + } else { + // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove + // and use fallback encoding. + if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "opcode_and_num_operands Huffman table for " + << spvOpcodeString(GetPrevOpcode()) + << "is missing kMarkvNoneOfTheAbove"; + writer_.WriteBits(bits, num_bits); + } + } + + // Fallback to base-rate codec. + codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); + assert(codec); + if (codec->Encode(word, &bits, &num_bits)) { + // The word was successfully encoded into bits/num_bits. + writer_.WriteBits(bits, num_bits); + return SPV_SUCCESS; + } else { + // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove + // and return false. + if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Global opcode_and_num_operands Huffman table is missing " + << "kMarkvNoneOfTheAbove"; + writer_.WriteBits(bits, num_bits); + return SPV_UNSUPPORTED; + } +} + +spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, + uint64_t fallback_method) { + const auto* codec = GetMtfHuffmanCodec(mtf); + if (!codec) { + assert(fallback_method != kMtfNone); + codec = GetMtfHuffmanCodec(fallback_method); + } + + if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank"; + + uint64_t bits = 0; + size_t num_bits = 0; + if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) { + // Encode using Huffman coding. + if (!codec->Encode(rank, &bits, &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to encode MTF rank with Huffman"; + + writer_.WriteBits(bits, num_bits); + } else { + // Encode by value. + if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to encode kMtfRankEncodedByValueSignal"; + + writer_.WriteBits(bits, num_bits); + writer_.WriteVariableWidthU32( + rank - MarkvCodec::kMtfSmallestRankEncodedByValue, + model_->mtf_rank_chunk_length()); + } + return SPV_SUCCESS; +} + +spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) { + // Get the descriptor for id. + const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id); + auto* codec = + model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); + uint64_t bits = 0; + size_t num_bits = 0; + uint64_t mtf = kMtfNone; + if (long_descriptor && codec && + codec->Encode(long_descriptor, &bits, &num_bits)) { + // If the descriptor exists and is in the table, write the descriptor and + // proceed to encoding the rank. + writer_.WriteBits(bits, num_bits); + mtf = GetMtfLongIdDescriptor(long_descriptor); + } else { + if (codec) { + // The descriptor doesn't exist or we have no coding for it. Write + // kMarkvNoneOfTheAbove and go to fallback method. + if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Descriptor Huffman table for " + << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " + << operand_index_ << " is missing kMarkvNoneOfTheAbove"; + + writer_.WriteBits(bits, num_bits); + } + + if (model_->id_fallback_strategy() != + MarkvModel::IdFallbackStrategy::kShortDescriptor) { + return SPV_UNSUPPORTED; + } + + const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id); + writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits); + + if (short_descriptor == 0) { + // Forward declared id. + return SPV_UNSUPPORTED; + } + + mtf = GetMtfShortIdDescriptor(short_descriptor); + } + + // Descriptor has been encoded. Now encode the rank of the id in the + // associated mtf sequence. + return EncodeExistingId(mtf, id); +} + +spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) { + assert(multi_mtf_.GetSize(mtf) > 0); + if (multi_mtf_.GetSize(mtf) == 1) { + // If the sequence has only one element no need to write rank, the decoder + // would make the same decision. + return SPV_SUCCESS; + } + + uint32_t rank = 0; + if (!multi_mtf_.RankFromValue(mtf, id, &rank)) + return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence"; + + return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank); +} + +spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) { + { + // Try to encode using id descriptor mtfs. + const spv_result_t result = EncodeIdWithDescriptor(id); + if (result != SPV_UNSUPPORTED) return result; + // If can't be done continue with other methods. + } + + const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( + SpvOp(inst_.opcode))(operand_index_); + uint32_t rank = 0; + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + // Encode using rule-based mtf. + uint64_t mtf = GetRuleBasedMtf(); + + if (mtf != kMtfNone && !can_forward_declare) { + assert(multi_mtf_.HasValue(kMtfAll, id)); + return EncodeExistingId(mtf, id); + } + + if (mtf == kMtfNone) mtf = kMtfAll; + + if (!multi_mtf_.RankFromValue(mtf, id, &rank)) { + // This is the first occurrence of a forward declared id. + multi_mtf_.Insert(kMtfAll, id); + multi_mtf_.Insert(kMtfForwardDeclared, id); + if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id); + rank = 0; + } + + return EncodeMtfRankHuffman(rank, mtf, kMtfAll); + } else { + assert(can_forward_declare); + + if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) { + // This is the first occurrence of a forward declared id. + multi_mtf_.Insert(kMtfForwardDeclared, id); + rank = 0; + } + + writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); + return SPV_SUCCESS; + } +} + +spv_result_t MarkvEncoder::EncodeTypeId() { + if (inst_.opcode == SpvOpFunctionParameter) { + assert(!remaining_function_parameter_types_.empty()); + assert(inst_.type_id == remaining_function_parameter_types_.front()); + remaining_function_parameter_types_.pop_front(); + return SPV_SUCCESS; + } + + { + // Try to encode using id descriptor mtfs. + const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id); + if (result != SPV_UNSUPPORTED) return result; + // If can't be done continue with other methods. + } + + assert(model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased); + + uint64_t mtf = GetRuleBasedMtf(); + assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( + operand_index_)); + + if (mtf == kMtfNone) { + mtf = kMtfTypeNonFunction; + // Function types should have been handled by GetRuleBasedMtf. + assert(inst_.opcode != SpvOpFunction); + } + + return EncodeExistingId(mtf, inst_.type_id); +} + +spv_result_t MarkvEncoder::EncodeResultId() { + uint32_t rank = 0; + + const uint64_t num_still_forward_declared = + multi_mtf_.GetSize(kMtfForwardDeclared); + + if (num_still_forward_declared) { + // We write the rank only if kMtfForwardDeclared is not empty. If it is + // empty the decoder knows that there are no forward declared ids to expect. + if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) { + // This is a definition of a forward declared id. We can remove the id + // from kMtfForwardDeclared. + if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to remove id from kMtfForwardDeclared"; + writer_.WriteBits(1, 1); + writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); + } else { + rank = 0; + writer_.WriteBits(0, 1); + } + } + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + if (!rank) { + multi_mtf_.Insert(kMtfAll, inst_.result_id); + } + } + + return SPV_SUCCESS; +} + +spv_result_t MarkvEncoder::EncodeLiteralNumber( + const spv_parsed_operand_t& operand) { + if (operand.number_bit_width <= 32) { + const uint32_t word = inst_.words[operand.offset]; + return EncodeNonIdWord(word); + } else { + assert(operand.number_bit_width <= 64); + const uint64_t word = uint64_t(inst_.words[operand.offset]) | + (uint64_t(inst_.words[operand.offset + 1]) << 32); + if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { + writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); + } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { + int64_t val = 0; + std::memcpy(&val, &word, 8); + writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), + model_->s64_block_exponent()); + } else if (operand.number_kind == SPV_NUMBER_FLOATING) { + writer_.WriteUnencoded(word); + } else { + return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; + } + } + return SPV_SUCCESS; +} + +void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) { + const size_t num_bits_to_next_byte = + GetNumBitsToNextByte(writer_.GetNumBits()); + if (num_bits_to_next_byte == 0 || + num_bits_to_next_byte > byte_break_if_less_than) + return; + + if (logger_) { + logger_->AppendWhitespaces(kCommentNumWhitespaces); + logger_->AppendText(""); + } + + writer_.WriteBits(0, num_bits_to_next_byte); +} + +spv_result_t MarkvEncoder::EncodeInstruction( + const spv_parsed_instruction_t& inst) { + SpvOp opcode = SpvOp(inst.opcode); + inst_ = inst; + + LogDisassemblyInstruction(); + + const spv_result_t opcode_encodig_result = + EncodeOpcodeAndNumOperands(opcode, inst.num_operands); + if (opcode_encodig_result < 0) return opcode_encodig_result; + + if (opcode_encodig_result != SPV_SUCCESS) { + // Fallback encoding for opcode and num_operands. + writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length()); + + if (!OpcodeHasFixedNumberOfOperands(opcode)) { + // If the opcode has a variable number of operands, encode the number of + // operands with the instruction. + + if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); + + writer_.WriteVariableWidthU16(inst.num_operands, + model_->num_operands_chunk_length()); + } + } + + // Write operands. + const uint32_t num_operands = inst_.num_operands; + for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) { + operand_ = inst_.operands[operand_index_]; + + if (logger_) { + logger_->AppendWhitespaces(kCommentNumWhitespaces); + logger_->AppendText("<"); + logger_->AppendText(spvOperandTypeStr(operand_.type)); + logger_->AppendText(">"); + } + + switch (operand_.type) { + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { + const uint32_t id = inst_.words[operand_.offset]; + if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) { + const spv_result_t result = EncodeTypeId(); + if (result != SPV_SUCCESS) return result; + } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) { + const spv_result_t result = EncodeResultId(); + if (result != SPV_SUCCESS) return result; + } else { + const spv_result_t result = EncodeRefId(id); + if (result != SPV_SUCCESS) return result; + } + + PromoteIfNeeded(id); + break; + } + + case SPV_OPERAND_TYPE_LITERAL_INTEGER: { + const spv_result_t result = + EncodeNonIdWord(inst_.words[operand_.offset]); + if (result != SPV_SUCCESS) return result; + break; + } + + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { + const spv_result_t result = EncodeLiteralNumber(operand_); + if (result != SPV_SUCCESS) return result; + break; + } + + case SPV_OPERAND_TYPE_LITERAL_STRING: { + const char* src = + reinterpret_cast(&inst_.words[operand_.offset]); + + auto* codec = model_->GetLiteralStringHuffmanCodec(opcode); + if (codec) { + uint64_t bits = 0; + size_t num_bits = 0; + const std::string str = src; + if (codec->Encode(str, &bits, &num_bits)) { + writer_.WriteBits(bits, num_bits); + break; + } else { + bool result = + codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits); + (void)result; + assert(result); + writer_.WriteBits(bits, num_bits); + } + } + + const size_t length = spv_strnlen_s(src, operand_.num_words * 4); + if (length == operand_.num_words * 4) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to find terminal character of literal string"; + for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); + break; + } + + default: { + for (int i = 0; i < operand_.num_words; ++i) { + const uint32_t word = inst_.words[operand_.offset + i]; + const spv_result_t result = EncodeNonIdWord(word); + if (result != SPV_SUCCESS) return result; + } + break; + } + } + } + + AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte); + + if (logger_) { + logger_->NewLine(); + logger_->NewLine(); + if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; + } + + ProcessCurInstruction(); + + return SPV_SUCCESS; +} + +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/comp/markv_encoder.h b/3rdparty/spirv-tools/source/comp/markv_encoder.h new file mode 100644 index 000000000..21843123f --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv_encoder.h @@ -0,0 +1,167 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/bit_stream.h" +#include "source/comp/markv.h" +#include "source/comp/markv_codec.h" +#include "source/comp/markv_logger.h" +#include "source/util/make_unique.h" + +#ifndef SOURCE_COMP_MARKV_ENCODER_H_ +#define SOURCE_COMP_MARKV_ENCODER_H_ + +#include + +namespace spvtools { +namespace comp { + +// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and +// EncodeInstruction which can be used as callback by spvBinaryParse. +// Encoded binary is written to an internally maintained bitstream. +// After the last instruction is encoded, the resulting MARK-V binary can be +// acquired by calling GetMarkvBinary(). +// +// The encoder uses SPIR-V validator to keep internal state, therefore +// SPIR-V binary needs to be able to pass validator checks. +// CreateCommentsLogger() can be used to enable the encoder to write comments +// on how encoding was done, which can later be accessed with GetComments(). +class MarkvEncoder : public MarkvCodec { + public: + // |model| is owned by the caller, must be not null and valid during the + // lifetime of MarkvEncoder. + MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options, + const MarkvModel* model) + : MarkvCodec(context, GetValidatorOptions(options), model), + options_(options) {} + ~MarkvEncoder() override = default; + + // Writes data from SPIR-V header to MARK-V header. + spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */, + uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t /* schema */) { + SetIdBound(id_bound); + header_.spirv_version = version; + header_.spirv_generator = generator; + return SPV_SUCCESS; + } + + // Creates an internal logger which writes comments on the encoding process. + void CreateLogger(MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer) { + logger_ = MakeUnique(log_consumer, debug_consumer); + writer_.SetCallback( + [this](const std::string& str) { logger_->AppendBitSequence(str); }); + } + + // Encodes SPIR-V instruction to MARK-V and writes to bit stream. + // Operation can fail if the instruction fails to pass the validator or if + // the encoder stubmles on something unexpected. + spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst); + + // Concatenates MARK-V header and the bit stream with encoded instructions + // into a single buffer and returns it as spv_markv_binary. The returned + // value is owned by the caller and needs to be destroyed with + // spvMarkvBinaryDestroy(). + std::vector GetMarkvBinary() { + header_.markv_length_in_bits = + static_cast(sizeof(header_) * 8 + writer_.GetNumBits()); + header_.markv_model = + (model_->model_type() << 16) | model_->model_version(); + + const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes(); + std::vector markv(num_bytes); + + assert(writer_.GetData()); + std::memcpy(markv.data(), &header_, sizeof(header_)); + std::memcpy(markv.data() + sizeof(header_), writer_.GetData(), + writer_.GetDataSizeBytes()); + return markv; + } + + // Optionally adds disassembly to the comments. + // Disassembly should contain all instructions in the module separated by + // \n, and no header. + void SetDisassembly(std::string&& disassembly) { + disassembly_ = MakeUnique(std::move(disassembly)); + } + + // Extracts the next instruction line from the disassembly and logs it. + void LogDisassemblyInstruction() { + if (logger_ && disassembly_) { + std::string line; + std::getline(*disassembly_, line, '\n'); + logger_->AppendTextNewLine(line); + } + } + + private: + // Creates and returns validator options. Returned value owned by the caller. + static spv_validator_options GetValidatorOptions( + const MarkvCodecOptions& options) { + return options.validate_spirv_binary ? spvValidatorOptionsCreate() + : nullptr; + } + + // Writes a single word to bit stream. operand_.type determines if the word is + // encoded and how. + spv_result_t EncodeNonIdWord(uint32_t word); + + // Writes both opcode and num_operands as a single code. + // Returns SPV_UNSUPPORTED iff no suitable codec was found. + spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode, + uint32_t num_operands); + + // Writes mtf rank to bit stream. |mtf| is used to determine the codec + // scheme. |fallback_method| is used if no codec defined for |mtf|. + spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, + uint64_t fallback_method); + + // Writes id using coding based on mtf associated with the id descriptor. + // Returns SPV_UNSUPPORTED iff fallback method needs to be used. + spv_result_t EncodeIdWithDescriptor(uint32_t id); + + // Writes id using coding based on the given |mtf|, which is expected to + // contain the given |id|. + spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id); + + // Writes type id of the current instruction if can't be inferred. + spv_result_t EncodeTypeId(); + + // Writes result id of the current instruction if can't be inferred. + spv_result_t EncodeResultId(); + + // Writes ids which are neither type nor result ids. + spv_result_t EncodeRefId(uint32_t id); + + // Writes bits to the stream until the beginning of the next byte if the + // number of bits until the next byte is less than |byte_break_if_less_than|. + void AddByteBreak(size_t byte_break_if_less_than); + + // Encodes a literal number operand and writes it to the bit stream. + spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand); + + MarkvCodecOptions options_; + + // Bit stream where encoded instructions are written. + BitWriterWord64 writer_; + + // If not nullptr, disassembled instruction lines will be written to comments. + // Format: \n separated instruction lines, no header. + std::unique_ptr disassembly_; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_ENCODER_H_ diff --git a/3rdparty/spirv-tools/source/comp/markv_logger.h b/3rdparty/spirv-tools/source/comp/markv_logger.h new file mode 100644 index 000000000..c07fe97b7 --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/markv_logger.h @@ -0,0 +1,93 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MARKV_LOGGER_H_ +#define SOURCE_COMP_MARKV_LOGGER_H_ + +#include "source/comp/markv.h" + +namespace spvtools { +namespace comp { + +class MarkvLogger { + public: + MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer) + : log_consumer_(log_consumer), debug_consumer_(debug_consumer) {} + + void AppendText(const std::string& str) { + Append(str); + use_delimiter_ = false; + } + + void AppendTextNewLine(const std::string& str) { + Append(str); + Append("\n"); + use_delimiter_ = false; + } + + void AppendBitSequence(const std::string& str) { + if (debug_consumer_) instruction_bits_ << str; + if (use_delimiter_) Append("-"); + Append(str); + use_delimiter_ = true; + } + + void AppendWhitespaces(size_t num) { + Append(std::string(num, ' ')); + use_delimiter_ = false; + } + + void NewLine() { + Append("\n"); + use_delimiter_ = false; + } + + bool DebugInstruction(const spv_parsed_instruction_t& inst) { + bool result = true; + if (debug_consumer_) { + result = debug_consumer_( + std::vector(inst.words, inst.words + inst.num_words), + instruction_bits_.str(), instruction_comment_.str()); + instruction_bits_.str(std::string()); + instruction_comment_.str(std::string()); + } + return result; + } + + private: + MarkvLogger(const MarkvLogger&) = delete; + MarkvLogger(MarkvLogger&&) = delete; + MarkvLogger& operator=(const MarkvLogger&) = delete; + MarkvLogger& operator=(MarkvLogger&&) = delete; + + void Append(const std::string& str) { + if (log_consumer_) log_consumer_(str); + if (debug_consumer_) instruction_comment_ << str; + } + + MarkvLogConsumer log_consumer_; + MarkvDebugConsumer debug_consumer_; + + std::stringstream instruction_bits_; + std::stringstream instruction_comment_; + + // If true a delimiter will be appended before the next bit sequence. + // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0. + bool use_delimiter_ = false; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_LOGGER_H_ diff --git a/3rdparty/spirv-tools/source/comp/markv_model.h b/3rdparty/spirv-tools/source/comp/markv_model.h index 606396e6d..d03df02df 100644 --- a/3rdparty/spirv-tools/source/comp/markv_model.h +++ b/3rdparty/spirv-tools/source/comp/markv_model.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Google Inc. +// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_COMP_MARKV_MODEL_H_ -#define LIBSPIRV_COMP_MARKV_MODEL_H_ +#ifndef SOURCE_COMP_MARKV_MODEL_H_ +#define SOURCE_COMP_MARKV_MODEL_H_ -#include #include -#include -#include "latest_version_spirv_header.h" -#include "spirv-tools/libspirv.h" -#include "util/huffman_codec.h" +#include "source/comp/huffman_codec.h" +#include "source/latest_version_spirv_header.h" +#include "spirv-tools/libspirv.hpp" namespace spvtools { +namespace comp { // Base class for MARK-V models. // The class contains encoding/decoding model with various constants and @@ -97,8 +96,8 @@ class MarkvModel { // Returns a codec for common opcode_and_num_operands words for the given // previous opcode. May return nullptr if the codec doesn't exist. - const spvutils::HuffmanCodec* - GetOpcodeAndNumOperandsMarkovHuffmanCodec(uint32_t prev_opcode) const { + const HuffmanCodec* GetOpcodeAndNumOperandsMarkovHuffmanCodec( + uint32_t prev_opcode) const { if (prev_opcode == SpvOpNop) return opcode_and_num_operands_huffman_codec_.get(); @@ -112,7 +111,7 @@ class MarkvModel { // Returns a codec for common non-id words used for given operand slot. // Operand slot is defined by the opcode and the operand index. // May return nullptr if the codec doesn't exist. - const spvutils::HuffmanCodec* GetNonIdWordHuffmanCodec( + const HuffmanCodec* GetNonIdWordHuffmanCodec( uint32_t opcode, uint32_t operand_index) const { const auto it = non_id_word_huffman_codecs_.find( std::pair(opcode, operand_index)); @@ -123,7 +122,7 @@ class MarkvModel { // Returns a codec for common id descriptos used for given operand slot. // Operand slot is defined by the opcode and the operand index. // May return nullptr if the codec doesn't exist. - const spvutils::HuffmanCodec* GetIdDescriptorHuffmanCodec( + const HuffmanCodec* GetIdDescriptorHuffmanCodec( uint32_t opcode, uint32_t operand_index) const { const auto it = id_descriptor_huffman_codecs_.find( std::pair(opcode, operand_index)); @@ -134,7 +133,7 @@ class MarkvModel { // Returns a codec for common strings used by the given opcode. // Operand slot is defined by the opcode and the operand index. // May return nullptr if the codec doesn't exist. - const spvutils::HuffmanCodec* GetLiteralStringHuffmanCodec( + const HuffmanCodec* GetLiteralStringHuffmanCodec( uint32_t opcode) const { const auto it = literal_string_huffman_codecs_.find(opcode); if (it == literal_string_huffman_codecs_.end()) return nullptr; @@ -178,23 +177,23 @@ class MarkvModel { protected: // Huffman codec for base-rate of opcode_and_num_operands. - std::unique_ptr> + std::unique_ptr> opcode_and_num_operands_huffman_codec_; // Huffman codecs for opcode_and_num_operands. The map key is previous opcode. - std::map>> + std::map>> opcode_and_num_operands_markov_huffman_codecs_; // Huffman codecs for non-id single-word operand values. // The map key is pair . std::map, - std::unique_ptr>> + std::unique_ptr>> non_id_word_huffman_codecs_; // Huffman codecs for id descriptors. The map key is pair // . std::map, - std::unique_ptr>> + std::unique_ptr>> id_descriptor_huffman_codecs_; // Set of all descriptors which have a coding scheme in any of @@ -205,7 +204,7 @@ class MarkvModel { // current instruction. This assumes, that there is no more than one literal // string operand per instruction, but would still work even if this is not // the case. Names and debug information strings are not collected. - std::map>> + std::map>> literal_string_huffman_codecs_; // Chunk lengths used for variable width encoding of operands (index is @@ -227,6 +226,7 @@ class MarkvModel { uint32_t model_version_ = 0; }; +} // namespace comp } // namespace spvtools -#endif // LIBSPIRV_COMP_MARKV_MODEL_H_ +#endif // SOURCE_COMP_MARKV_MODEL_H_ diff --git a/3rdparty/spirv-tools/source/comp/move_to_front.cpp b/3rdparty/spirv-tools/source/comp/move_to_front.cpp new file mode 100644 index 000000000..9d35a3f5b --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/move_to_front.cpp @@ -0,0 +1,456 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/move_to_front.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace spvtools { +namespace comp { + +bool MoveToFront::Insert(uint32_t value) { + auto it = value_to_node_.find(value); + if (it != value_to_node_.end() && IsInTree(it->second)) return false; + + const uint32_t old_size = GetSize(); + (void)old_size; + + InsertNode(CreateNode(next_timestamp_++, value)); + + last_accessed_value_ = value; + last_accessed_value_valid_ = true; + + assert(value_to_node_.count(value)); + assert(old_size + 1 == GetSize()); + return true; +} + +bool MoveToFront::Remove(uint32_t value) { + auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) return false; + + if (!IsInTree(it->second)) return false; + + if (last_accessed_value_ == value) last_accessed_value_valid_ = false; + + const uint32_t orphan = RemoveNode(it->second); + (void)orphan; + // The node of |value| is still alive but it's orphaned now. Can still be + // reused later. + assert(!IsInTree(orphan)); + assert(ValueOf(orphan) == value); + return true; +} + +bool MoveToFront::RankFromValue(uint32_t value, uint32_t* rank) { + if (last_accessed_value_valid_ && last_accessed_value_ == value) { + *rank = 1; + return true; + } + + const uint32_t old_size = GetSize(); + if (old_size == 1) { + if (ValueOf(root_) == value) { + *rank = 1; + return true; + } else { + return false; + } + } + + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) { + return false; + } + + uint32_t target = it->second; + + if (!IsInTree(target)) { + return false; + } + + uint32_t node = target; + *rank = 1 + SizeOf(LeftOf(node)); + while (node) { + if (IsRightChild(node)) *rank += 1 + SizeOf(LeftOf(ParentOf(node))); + node = ParentOf(node); + } + + // Don't update timestamp if the node has rank 1. + if (*rank != 1) { + // Update timestamp and reposition the node. + target = RemoveNode(target); + assert(ValueOf(target) == value); + assert(old_size == GetSize() + 1); + MutableTimestampOf(target) = next_timestamp_++; + InsertNode(target); + assert(old_size == GetSize()); + } + + last_accessed_value_ = value; + last_accessed_value_valid_ = true; + return true; +} + +bool MoveToFront::HasValue(uint32_t value) const { + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) { + return false; + } + + return IsInTree(it->second); +} + +bool MoveToFront::Promote(uint32_t value) { + if (last_accessed_value_valid_ && last_accessed_value_ == value) { + return true; + } + + const uint32_t old_size = GetSize(); + if (old_size == 1) return ValueOf(root_) == value; + + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) { + return false; + } + + uint32_t target = it->second; + + if (!IsInTree(target)) { + return false; + } + + // Update timestamp and reposition the node. + target = RemoveNode(target); + assert(ValueOf(target) == value); + assert(old_size == GetSize() + 1); + + MutableTimestampOf(target) = next_timestamp_++; + InsertNode(target); + assert(old_size == GetSize()); + + last_accessed_value_ = value; + last_accessed_value_valid_ = true; + return true; +} + +bool MoveToFront::ValueFromRank(uint32_t rank, uint32_t* value) { + if (last_accessed_value_valid_ && rank == 1) { + *value = last_accessed_value_; + return true; + } + + const uint32_t old_size = GetSize(); + if (rank <= 0 || rank > old_size) { + return false; + } + + if (old_size == 1) { + *value = ValueOf(root_); + return true; + } + + const bool update_timestamp = (rank != 1); + + uint32_t node = root_; + while (node) { + const uint32_t left_subtree_num_nodes = SizeOf(LeftOf(node)); + if (rank == left_subtree_num_nodes + 1) { + // This is the node we are looking for. + // Don't update timestamp if the node has rank 1. + if (update_timestamp) { + node = RemoveNode(node); + assert(old_size == GetSize() + 1); + MutableTimestampOf(node) = next_timestamp_++; + InsertNode(node); + assert(old_size == GetSize()); + } + *value = ValueOf(node); + last_accessed_value_ = *value; + last_accessed_value_valid_ = true; + return true; + } + + if (rank < left_subtree_num_nodes + 1) { + // Descend into the left subtree. The rank is still valid. + node = LeftOf(node); + } else { + // Descend into the right subtree. We leave behind the left subtree and + // the current node, adjust the |rank| accordingly. + rank -= left_subtree_num_nodes + 1; + node = RightOf(node); + } + } + + assert(0); + return false; +} + +uint32_t MoveToFront::CreateNode(uint32_t timestamp, uint32_t value) { + uint32_t handle = static_cast(nodes_.size()); + const auto result = value_to_node_.emplace(value, handle); + if (result.second) { + // Create new node. + nodes_.emplace_back(Node()); + Node& node = nodes_.back(); + node.timestamp = timestamp; + node.value = value; + node.size = 1; + // Non-NIL nodes start with height 1 because their NIL children are + // leaves. + node.height = 1; + } else { + // Reuse old node. + handle = result.first->second; + assert(!IsInTree(handle)); + assert(ValueOf(handle) == value); + assert(SizeOf(handle) == 1); + assert(HeightOf(handle) == 1); + MutableTimestampOf(handle) = timestamp; + } + + return handle; +} + +void MoveToFront::InsertNode(uint32_t node) { + assert(!IsInTree(node)); + assert(SizeOf(node) == 1); + assert(HeightOf(node) == 1); + assert(TimestampOf(node)); + + if (!root_) { + root_ = node; + return; + } + + uint32_t iter = root_; + uint32_t parent = 0; + + // Will determine if |node| will become the right or left child after + // insertion (but before balancing). + bool right_child = true; + + // Find the node which will become |node|'s parent after insertion + // (but before balancing). + while (iter) { + parent = iter; + assert(TimestampOf(iter) != TimestampOf(node)); + right_child = TimestampOf(iter) > TimestampOf(node); + iter = right_child ? RightOf(iter) : LeftOf(iter); + } + + assert(parent); + + // Connect node and parent. + MutableParentOf(node) = parent; + if (right_child) + MutableRightOf(parent) = node; + else + MutableLeftOf(parent) = node; + + // Insertion is finished. Start the balancing process. + bool needs_rebalancing = true; + parent = ParentOf(node); + + while (parent) { + UpdateNode(parent); + + if (needs_rebalancing) { + const int parent_balance = BalanceOf(parent); + + if (RightOf(parent) == node) { + // Added node to the right subtree. + if (parent_balance > 1) { + // Parent is right heavy, rotate left. + if (BalanceOf(node) < 0) RotateRight(node); + parent = RotateLeft(parent); + } else if (parent_balance == 0 || parent_balance == -1) { + // Parent is balanced or left heavy, no need to balance further. + needs_rebalancing = false; + } + } else { + // Added node to the left subtree. + if (parent_balance < -1) { + // Parent is left heavy, rotate right. + if (BalanceOf(node) > 0) RotateLeft(node); + parent = RotateRight(parent); + } else if (parent_balance == 0 || parent_balance == 1) { + // Parent is balanced or right heavy, no need to balance further. + needs_rebalancing = false; + } + } + } + + assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1)); + + node = parent; + parent = ParentOf(parent); + } +} + +uint32_t MoveToFront::RemoveNode(uint32_t node) { + if (LeftOf(node) && RightOf(node)) { + // If |node| has two children, then use another node as scapegoat and swap + // their contents. We pick the scapegoat on the side of the tree which has + // more nodes. + const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node)) + ? RightestDescendantOf(LeftOf(node)) + : LeftestDescendantOf(RightOf(node)); + assert(scapegoat); + std::swap(MutableValueOf(node), MutableValueOf(scapegoat)); + std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat)); + value_to_node_[ValueOf(node)] = node; + value_to_node_[ValueOf(scapegoat)] = scapegoat; + node = scapegoat; + } + + // |node| may have only one child at this point. + assert(!RightOf(node) || !LeftOf(node)); + + uint32_t parent = ParentOf(node); + uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node); + + // Orphan |node| and reconnect parent and child. + if (child) MutableParentOf(child) = parent; + + if (parent) { + if (LeftOf(parent) == node) + MutableLeftOf(parent) = child; + else + MutableRightOf(parent) = child; + } + + MutableParentOf(node) = 0; + MutableLeftOf(node) = 0; + MutableRightOf(node) = 0; + UpdateNode(node); + const uint32_t orphan = node; + + if (root_ == node) root_ = child; + + // Removal is finished. Start the balancing process. + bool needs_rebalancing = true; + node = child; + + while (parent) { + UpdateNode(parent); + + if (needs_rebalancing) { + const int parent_balance = BalanceOf(parent); + + if (parent_balance == 1 || parent_balance == -1) { + // The height of the subtree was not changed. + needs_rebalancing = false; + } else { + if (RightOf(parent) == node) { + // Removed node from the right subtree. + if (parent_balance < -1) { + // Parent is left heavy, rotate right. + const uint32_t sibling = LeftOf(parent); + if (BalanceOf(sibling) > 0) RotateLeft(sibling); + parent = RotateRight(parent); + } + } else { + // Removed node from the left subtree. + if (parent_balance > 1) { + // Parent is right heavy, rotate left. + const uint32_t sibling = RightOf(parent); + if (BalanceOf(sibling) < 0) RotateRight(sibling); + parent = RotateLeft(parent); + } + } + } + } + + assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1)); + + node = parent; + parent = ParentOf(parent); + } + + return orphan; +} + +uint32_t MoveToFront::RotateLeft(const uint32_t node) { + const uint32_t pivot = RightOf(node); + assert(pivot); + + // LeftOf(pivot) gets attached to node in place of pivot. + MutableRightOf(node) = LeftOf(pivot); + if (RightOf(node)) MutableParentOf(RightOf(node)) = node; + + // Pivot gets attached to ParentOf(node) in place of node. + MutableParentOf(pivot) = ParentOf(node); + if (!ParentOf(node)) + root_ = pivot; + else if (IsLeftChild(node)) + MutableLeftOf(ParentOf(node)) = pivot; + else + MutableRightOf(ParentOf(node)) = pivot; + + // Node is child of pivot. + MutableLeftOf(pivot) = node; + MutableParentOf(node) = pivot; + + // Update both node and pivot. Pivot is the new parent of node, so node should + // be updated first. + UpdateNode(node); + UpdateNode(pivot); + + return pivot; +} + +uint32_t MoveToFront::RotateRight(const uint32_t node) { + const uint32_t pivot = LeftOf(node); + assert(pivot); + + // RightOf(pivot) gets attached to node in place of pivot. + MutableLeftOf(node) = RightOf(pivot); + if (LeftOf(node)) MutableParentOf(LeftOf(node)) = node; + + // Pivot gets attached to ParentOf(node) in place of node. + MutableParentOf(pivot) = ParentOf(node); + if (!ParentOf(node)) + root_ = pivot; + else if (IsLeftChild(node)) + MutableLeftOf(ParentOf(node)) = pivot; + else + MutableRightOf(ParentOf(node)) = pivot; + + // Node is child of pivot. + MutableRightOf(pivot) = node; + MutableParentOf(node) = pivot; + + // Update both node and pivot. Pivot is the new parent of node, so node should + // be updated first. + UpdateNode(node); + UpdateNode(pivot); + + return pivot; +} + +void MoveToFront::UpdateNode(uint32_t node) { + MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node)); + MutableHeightOf(node) = + 1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node))); +} + +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/comp/move_to_front.h b/3rdparty/spirv-tools/source/comp/move_to_front.h new file mode 100644 index 000000000..8752194ec --- /dev/null +++ b/3rdparty/spirv-tools/source/comp/move_to_front.h @@ -0,0 +1,384 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MOVE_TO_FRONT_H_ +#define SOURCE_COMP_MOVE_TO_FRONT_H_ + +#include +#include +#include +#include +#include +#include + +namespace spvtools { +namespace comp { + +// Log(n) move-to-front implementation. Implements the following functions: +// Insert - pushes value to the front of the mtf sequence +// (only unique values allowed). +// Remove - remove value from the sequence. +// ValueFromRank - access value by its 1-indexed rank in the sequence. +// RankFromValue - get the rank of the given value in the sequence. +// Accessing a value with ValueFromRank or RankFromValue moves the value to the +// front of the sequence (rank of 1). +// +// The implementation is based on an AVL-based order statistic tree. The tree +// is ordered by timestamps issued when values are inserted or accessed (recent +// values go to the left side of the tree, old values are gradually rotated to +// the right side). +// +// Terminology +// rank: 1-indexed rank showing how recently the value was inserted or accessed. +// node: handle used internally to access node data. +// size: size of the subtree of a node (including the node). +// height: distance from a node to the farthest leaf. +class MoveToFront { + public: + explicit MoveToFront(size_t reserve_capacity = 4) { + nodes_.reserve(reserve_capacity); + + // Create NIL node. + nodes_.emplace_back(Node()); + } + + virtual ~MoveToFront() = default; + + // Inserts value in the move-to-front sequence. Does nothing if the value is + // already in the sequence. Returns true if insertion was successful. + // The inserted value is placed at the front of the sequence (rank 1). + bool Insert(uint32_t value); + + // Removes value from move-to-front sequence. Returns false iff the value + // was not found. + bool Remove(uint32_t value); + + // Computes 1-indexed rank of value in the move-to-front sequence and moves + // the value to the front. Example: + // Before the call: 4 8 2 1 7 + // RankFromValue(8) returns 2 + // After the call: 8 4 2 1 7 + // Returns true iff the value was found in the sequence. + bool RankFromValue(uint32_t value, uint32_t* rank); + + // Returns value corresponding to a 1-indexed rank in the move-to-front + // sequence and moves the value to the front. Example: + // Before the call: 4 8 2 1 7 + // ValueFromRank(2) returns 8 + // After the call: 8 4 2 1 7 + // Returns true iff the rank is within bounds [1, GetSize()]. + bool ValueFromRank(uint32_t rank, uint32_t* value); + + // Moves the value to the front of the sequence. + // Returns false iff value is not in the sequence. + bool Promote(uint32_t value); + + // Returns true iff the move-to-front sequence contains the value. + bool HasValue(uint32_t value) const; + + // Returns the number of elements in the move-to-front sequence. + uint32_t GetSize() const { return SizeOf(root_); } + + protected: + // Internal tree data structure uses handles instead of pointers. Leaves and + // root parent reference a singleton under handle 0. Although dereferencing + // a null pointer is not possible, inappropriate access to handle 0 would + // cause an assertion. Handles are not garbage collected if value was + // deprecated + // with DeprecateValue(). But handles are recycled when a node is + // repositioned. + + // Internal tree data structure node. + struct Node { + // Timestamp from a logical clock which updates every time the element is + // accessed through ValueFromRank or RankFromValue. + uint32_t timestamp = 0; + // The size of the node's subtree, including the node. + // SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1. + uint32_t size = 0; + // Handles to connected nodes. + uint32_t left = 0; + uint32_t right = 0; + uint32_t parent = 0; + // Distance to the farthest leaf. + // Leaves have height 0, real nodes at least 1. + uint32_t height = 0; + // Stored value. + uint32_t value = 0; + }; + + // Creates node and sets correct values. Non-NIL nodes should be created only + // through this function. If the node with this value has been created + // previously + // and since orphaned, reuses the old node instead of creating a new one. + uint32_t CreateNode(uint32_t timestamp, uint32_t value); + + // Node accessor methods. Naming is designed to be similar to natural + // language as these functions tend to be used in sequences, for example: + // ParentOf(LeftestDescendentOf(RightOf(node))) + + // Returns value of the node referenced by |handle|. + uint32_t ValueOf(uint32_t node) const { return nodes_.at(node).value; } + + // Returns left child of |node|. + uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; } + + // Returns right child of |node|. + uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; } + + // Returns parent of |node|. + uint32_t ParentOf(uint32_t node) const { return nodes_.at(node).parent; } + + // Returns timestamp of |node|. + uint32_t TimestampOf(uint32_t node) const { + assert(node); + return nodes_.at(node).timestamp; + } + + // Returns size of |node|. + uint32_t SizeOf(uint32_t node) const { return nodes_.at(node).size; } + + // Returns height of |node|. + uint32_t HeightOf(uint32_t node) const { return nodes_.at(node).height; } + + // Returns mutable reference to value of |node|. + uint32_t& MutableValueOf(uint32_t node) { + assert(node); + return nodes_.at(node).value; + } + + // Returns mutable reference to handle of left child of |node|. + uint32_t& MutableLeftOf(uint32_t node) { + assert(node); + return nodes_.at(node).left; + } + + // Returns mutable reference to handle of right child of |node|. + uint32_t& MutableRightOf(uint32_t node) { + assert(node); + return nodes_.at(node).right; + } + + // Returns mutable reference to handle of parent of |node|. + uint32_t& MutableParentOf(uint32_t node) { + assert(node); + return nodes_.at(node).parent; + } + + // Returns mutable reference to timestamp of |node|. + uint32_t& MutableTimestampOf(uint32_t node) { + assert(node); + return nodes_.at(node).timestamp; + } + + // Returns mutable reference to size of |node|. + uint32_t& MutableSizeOf(uint32_t node) { + assert(node); + return nodes_.at(node).size; + } + + // Returns mutable reference to height of |node|. + uint32_t& MutableHeightOf(uint32_t node) { + assert(node); + return nodes_.at(node).height; + } + + // Returns true iff |node| is left child of its parent. + bool IsLeftChild(uint32_t node) const { + assert(node); + return LeftOf(ParentOf(node)) == node; + } + + // Returns true iff |node| is right child of its parent. + bool IsRightChild(uint32_t node) const { + assert(node); + return RightOf(ParentOf(node)) == node; + } + + // Returns true iff |node| has no relatives. + bool IsOrphan(uint32_t node) const { + assert(node); + return !ParentOf(node) && !LeftOf(node) && !RightOf(node); + } + + // Returns true iff |node| is in the tree. + bool IsInTree(uint32_t node) const { + assert(node); + return node == root_ || !IsOrphan(node); + } + + // Returns the height difference between right and left subtrees. + int BalanceOf(uint32_t node) const { + return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node))); + } + + // Updates size and height of the node, assuming that the children have + // correct values. + void UpdateNode(uint32_t node); + + // Returns the most LeftOf(LeftOf(... descendent which is not leaf. + uint32_t LeftestDescendantOf(uint32_t node) const { + uint32_t parent = 0; + while (node) { + parent = node; + node = LeftOf(node); + } + return parent; + } + + // Returns the most RightOf(RightOf(... descendent which is not leaf. + uint32_t RightestDescendantOf(uint32_t node) const { + uint32_t parent = 0; + while (node) { + parent = node; + node = RightOf(node); + } + return parent; + } + + // Inserts node in the tree. The node must be an orphan. + void InsertNode(uint32_t node); + + // Removes node from the tree. May change value_to_node_ if removal uses a + // scapegoat. Returns the removed (orphaned) handle for recycling. The + // returned handle may not be equal to |node| if scapegoat was used. + uint32_t RemoveNode(uint32_t node); + + // Rotates |node| left, reassigns all connections and returns the node + // which takes place of the |node|. + uint32_t RotateLeft(const uint32_t node); + + // Rotates |node| right, reassigns all connections and returns the node + // which takes place of the |node|. + uint32_t RotateRight(const uint32_t node); + + // Root node handle. The tree is empty if root_ is 0. + uint32_t root_ = 0; + + // Incremented counters for next timestamp and value. + uint32_t next_timestamp_ = 1; + + // Holds all tree nodes. Indices of this vector are node handles. + std::vector nodes_; + + // Maps ids to node handles. + std::unordered_map value_to_node_; + + // Cache for the last accessed value in the sequence. + uint32_t last_accessed_value_ = 0; + bool last_accessed_value_valid_ = false; +}; + +class MultiMoveToFront { + public: + // Inserts |value| to sequence with handle |mtf|. + // Returns false if |mtf| already has |value|. + bool Insert(uint64_t mtf, uint32_t value) { + if (GetMtf(mtf).Insert(value)) { + val_to_mtfs_[value].insert(mtf); + return true; + } + return false; + } + + // Removes |value| from sequence with handle |mtf|. + // Returns false if |mtf| doesn't have |value|. + bool Remove(uint64_t mtf, uint32_t value) { + if (GetMtf(mtf).Remove(value)) { + val_to_mtfs_[value].erase(mtf); + return true; + } + assert(val_to_mtfs_[value].count(mtf) == 0); + return false; + } + + // Removes |value| from all sequences which have it. + void RemoveFromAll(uint32_t value) { + auto it = val_to_mtfs_.find(value); + if (it == val_to_mtfs_.end()) return; + + auto& mtfs_containing_value = it->second; + for (uint64_t mtf : mtfs_containing_value) { + GetMtf(mtf).Remove(value); + } + + val_to_mtfs_.erase(value); + } + + // Computes rank of |value| in sequence |mtf|. + // Returns false if |mtf| doesn't have |value|. + bool RankFromValue(uint64_t mtf, uint32_t value, uint32_t* rank) { + return GetMtf(mtf).RankFromValue(value, rank); + } + + // Finds |value| with |rank| in sequence |mtf|. + // Returns false if |rank| is out of bounds. + bool ValueFromRank(uint64_t mtf, uint32_t rank, uint32_t* value) { + return GetMtf(mtf).ValueFromRank(rank, value); + } + + // Returns size of |mtf| sequence. + uint32_t GetSize(uint64_t mtf) { return GetMtf(mtf).GetSize(); } + + // Promotes |value| in all sequences which have it. + void Promote(uint32_t value) { + const auto it = val_to_mtfs_.find(value); + if (it == val_to_mtfs_.end()) return; + + const auto& mtfs_containing_value = it->second; + for (uint64_t mtf : mtfs_containing_value) { + GetMtf(mtf).Promote(value); + } + } + + // Inserts |value| in sequence |mtf| or promotes if it's already there. + void InsertOrPromote(uint64_t mtf, uint32_t value) { + if (!Insert(mtf, value)) { + GetMtf(mtf).Promote(value); + } + } + + // Returns if |mtf| sequence has |value|. + bool HasValue(uint64_t mtf, uint32_t value) { + return GetMtf(mtf).HasValue(value); + } + + private: + // Returns actual MoveToFront object corresponding to |handle|. + // As multiple operations are often performed consecutively for the same + // sequence, the last returned value is cached. + MoveToFront& GetMtf(uint64_t handle) { + if (!cached_mtf_ || cached_handle_ != handle) { + cached_handle_ = handle; + cached_mtf_ = &mtfs_[handle]; + } + + return *cached_mtf_; + } + + // Container holding MoveToFront objects. Map key is sequence handle. + std::map mtfs_; + + // Container mapping value to sequences which contain that value. + std::unordered_map> val_to_mtfs_; + + // Cache for the last accessed sequence. + uint64_t cached_handle_ = 0; + MoveToFront* cached_mtf_ = nullptr; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MOVE_TO_FRONT_H_ diff --git a/3rdparty/spirv-tools/source/diagnostic.cpp b/3rdparty/spirv-tools/source/diagnostic.cpp index 578712069..edc27c8fd 100644 --- a/3rdparty/spirv-tools/source/diagnostic.cpp +++ b/3rdparty/spirv-tools/source/diagnostic.cpp @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "diagnostic.h" +#include "source/diagnostic.h" #include #include #include #include +#include -#include "table.h" +#include "source/table.h" // Diagnostic API @@ -57,20 +58,23 @@ spv_result_t spvDiagnosticPrint(const spv_diagnostic diagnostic) { << diagnostic->position.column + 1 << ": " << diagnostic->error << "\n"; return SPV_SUCCESS; - } else { - // NOTE: Assume this is a binary position - std::cerr << "error: " << diagnostic->position.index << ": " - << diagnostic->error << "\n"; - return SPV_SUCCESS; } + + // NOTE: Assume this is a binary position + std::cerr << "error: "; + if (diagnostic->position.index > 0) + std::cerr << diagnostic->position.index << ": "; + std::cerr << diagnostic->error << "\n"; + return SPV_SUCCESS; } -namespace libspirv { +namespace spvtools { DiagnosticStream::DiagnosticStream(DiagnosticStream&& other) : stream_(), position_(other.position_), consumer_(other.consumer_), + disassembled_instruction_(std::move(other.disassembled_instruction_)), error_(other.error_) { // Prevent the other object from emitting output during destruction. other.error_ = SPV_FAILED_MATCH; @@ -102,6 +106,9 @@ DiagnosticStream::~DiagnosticStream() { default: break; } + if (disassembled_instruction_.size() > 0) + stream_ << std::endl << " " << disassembled_instruction_ << std::endl; + consumer_(level, "input", position_, stream_.str().c_str()); } } @@ -117,7 +124,7 @@ void UseDiagnosticAsMessageConsumer(spv_context context, spvDiagnosticDestroy(*diagnostic); // Avoid memory leak. *diagnostic = spvDiagnosticCreate(&p, message); }; - libspirv::SetContextMessageConsumer(context, std::move(create_diagnostic)); + SetContextMessageConsumer(context, std::move(create_diagnostic)); } std::string spvResultToString(spv_result_t res) { @@ -183,4 +190,4 @@ std::string spvResultToString(spv_result_t res) { return out; } -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/diagnostic.h b/3rdparty/spirv-tools/source/diagnostic.h index 3d06f8df0..22df96143 100644 --- a/3rdparty/spirv-tools/source/diagnostic.h +++ b/3rdparty/spirv-tools/source/diagnostic.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_DIAGNOSTIC_H_ -#define LIBSPIRV_DIAGNOSTIC_H_ +#ifndef SOURCE_DIAGNOSTIC_H_ +#define SOURCE_DIAGNOSTIC_H_ #include #include #include "spirv-tools/libspirv.hpp" -namespace libspirv { +namespace spvtools { // A DiagnosticStream remembers the current position of the input and an error // code, and captures diagnostic messages via the left-shift operator. @@ -28,10 +28,13 @@ namespace libspirv { // emitted during the destructor. class DiagnosticStream { public: - DiagnosticStream(spv_position_t position, - const spvtools::MessageConsumer& consumer, + DiagnosticStream(spv_position_t position, const MessageConsumer& consumer, + const std::string& disassembled_instruction, spv_result_t error) - : position_(position), consumer_(consumer), error_(error) {} + : position_(position), + consumer_(consumer), + disassembled_instruction_(disassembled_instruction), + error_(error) {} // Creates a DiagnosticStream from an expiring DiagnosticStream. // The new object takes the contents of the other, and prevents the @@ -56,7 +59,8 @@ class DiagnosticStream { private: std::ostringstream stream_; spv_position_t position_; - spvtools::MessageConsumer consumer_; // Message consumer callback. + MessageConsumer consumer_; // Message consumer callback. + std::string disassembled_instruction_; spv_result_t error_; }; @@ -70,6 +74,6 @@ void UseDiagnosticAsMessageConsumer(spv_context context, std::string spvResultToString(spv_result_t res); -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_DIAGNOSTIC_H_ +#endif // SOURCE_DIAGNOSTIC_H_ diff --git a/3rdparty/spirv-tools/source/disassemble.cpp b/3rdparty/spirv-tools/source/disassemble.cpp index 909886c09..c116f5072 100644 --- a/3rdparty/spirv-tools/source/disassemble.cpp +++ b/3rdparty/spirv-tools/source/disassemble.cpp @@ -21,20 +21,22 @@ #include #include #include +#include -#include "assembly_grammar.h" -#include "binary.h" -#include "diagnostic.h" -#include "disassemble.h" -#include "ext_inst.h" -#include "name_mapper.h" -#include "opcode.h" -#include "parsed_operand.h" -#include "print.h" +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/disassemble.h" +#include "source/ext_inst.h" +#include "source/name_mapper.h" +#include "source/opcode.h" +#include "source/parsed_operand.h" +#include "source/print.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" #include "spirv-tools/libspirv.h" -#include "spirv_constant.h" -#include "spirv_endian.h" -#include "util/hex_float.h" namespace { @@ -42,8 +44,8 @@ namespace { // representation. class Disassembler { public: - Disassembler(const libspirv::AssemblyGrammar& grammar, uint32_t options, - libspirv::NameMapper name_mapper) + Disassembler(const spvtools::AssemblyGrammar& grammar, uint32_t options, + spvtools::NameMapper name_mapper) : grammar_(grammar), print_(spvIsInBitfield(SPV_BINARY_TO_TEXT_OPTION_PRINT, options)), color_(spvIsInBitfield(SPV_BINARY_TO_TEXT_OPTION_COLOR, options)), @@ -75,7 +77,7 @@ class Disassembler { private: enum { kStandardIndent = 15 }; - using out_stream = libspirv::out_stream; + using out_stream = spvtools::out_stream; // Emits an operand for the given instruction, where the instruction // is at offset words from the start of the binary. @@ -87,30 +89,30 @@ class Disassembler { // Resets the output color, if color is turned on. void ResetColor() { - if (color_) out_.get() << libspirv::clr::reset{print_}; + if (color_) out_.get() << spvtools::clr::reset{print_}; } // Sets the output to grey, if color is turned on. void SetGrey() { - if (color_) out_.get() << libspirv::clr::grey{print_}; + if (color_) out_.get() << spvtools::clr::grey{print_}; } // Sets the output to blue, if color is turned on. void SetBlue() { - if (color_) out_.get() << libspirv::clr::blue{print_}; + if (color_) out_.get() << spvtools::clr::blue{print_}; } // Sets the output to yellow, if color is turned on. void SetYellow() { - if (color_) out_.get() << libspirv::clr::yellow{print_}; + if (color_) out_.get() << spvtools::clr::yellow{print_}; } // Sets the output to red, if color is turned on. void SetRed() { - if (color_) out_.get() << libspirv::clr::red{print_}; + if (color_) out_.get() << spvtools::clr::red{print_}; } // Sets the output to green, if color is turned on. void SetGreen() { - if (color_) out_.get() << libspirv::clr::green{print_}; + if (color_) out_.get() << spvtools::clr::green{print_}; } - const libspirv::AssemblyGrammar& grammar_; + const spvtools::AssemblyGrammar& grammar_; const bool print_; // Should we also print to the standard output stream? const bool color_; // Should we print in colour? const int indent_; // How much to indent. 0 means don't indent @@ -121,7 +123,7 @@ class Disassembler { const bool header_; // Should we output header as the leading comment? const bool show_byte_offset_; // Should we print byte offset, in hex? size_t byte_offset_; // The number of bytes processed so far. - libspirv::NameMapper name_mapper_; + spvtools::NameMapper name_mapper_; }; spv_result_t Disassembler::HandleHeader(spv_endianness_t endian, @@ -230,7 +232,7 @@ void Disassembler::EmitOperand(const spv_parsed_instruction_t& inst, case SPV_OPERAND_TYPE_LITERAL_INTEGER: case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { SetRed(); - libspirv::EmitNumericLiteral(&stream_, inst, operand); + spvtools::EmitNumericLiteral(&stream_, inst, operand); ResetColor(); } break; case SPV_OPERAND_TYPE_LITERAL_STRING: { @@ -399,7 +401,7 @@ spv_result_t DisassembleTargetInstruction( return SPV_SUCCESS; } -} // anonymous namespace +} // namespace spv_result_t spvBinaryToText(const spv_const_context context, const uint32_t* code, const size_t wordCount, @@ -408,18 +410,18 @@ spv_result_t spvBinaryToText(const spv_const_context context, spv_context_t hijack_context = *context; if (pDiagnostic) { *pDiagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); } - const libspirv::AssemblyGrammar grammar(&hijack_context); + const spvtools::AssemblyGrammar grammar(&hijack_context); if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE; // Generate friendly names for Ids if requested. - std::unique_ptr friendly_mapper; - libspirv::NameMapper name_mapper = libspirv::GetTrivialNameMapper(); + std::unique_ptr friendly_mapper; + spvtools::NameMapper name_mapper = spvtools::GetTrivialNameMapper(); if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) { - friendly_mapper.reset( - new libspirv::FriendlyNameMapper(&hijack_context, code, wordCount)); + friendly_mapper = spvtools::MakeUnique( + &hijack_context, code, wordCount); name_mapper = friendly_mapper->GetNameMapper(); } @@ -441,18 +443,18 @@ std::string spvtools::spvInstructionBinaryToText(const spv_target_env env, const size_t wordCount, const uint32_t options) { spv_context context = spvContextCreate(env); - const libspirv::AssemblyGrammar grammar(context); + const spvtools::AssemblyGrammar grammar(context); if (!grammar.isValid()) { spvContextDestroy(context); return ""; } // Generate friendly names for Ids if requested. - std::unique_ptr friendly_mapper; - libspirv::NameMapper name_mapper = libspirv::GetTrivialNameMapper(); + std::unique_ptr friendly_mapper; + spvtools::NameMapper name_mapper = spvtools::GetTrivialNameMapper(); if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) { - friendly_mapper.reset( - new libspirv::FriendlyNameMapper(context, code, wordCount)); + friendly_mapper = spvtools::MakeUnique( + context, code, wordCount); name_mapper = friendly_mapper->GetNameMapper(); } diff --git a/3rdparty/spirv-tools/source/disassemble.h b/3rdparty/spirv-tools/source/disassemble.h index b833dd07a..ac3574272 100644 --- a/3rdparty/spirv-tools/source/disassemble.h +++ b/3rdparty/spirv-tools/source/disassemble.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_DISASSEMBLE_H_ -#define SPIRV_TOOLS_DISASSEMBLE_H_ +#ifndef SOURCE_DISASSEMBLE_H_ +#define SOURCE_DISASSEMBLE_H_ #include @@ -35,4 +35,4 @@ std::string spvInstructionBinaryToText(const spv_target_env env, } // namespace spvtools -#endif // SPIRV_TOOLS_DISASSEMBLE_H_ +#endif // SOURCE_DISASSEMBLE_H_ diff --git a/3rdparty/spirv-tools/source/enum_set.h b/3rdparty/spirv-tools/source/enum_set.h index 75a49f06b..e4ef297cd 100644 --- a/3rdparty/spirv-tools/source/enum_set.h +++ b/3rdparty/spirv-tools/source/enum_set.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_ENUM_SET_H -#define LIBSPIRV_ENUM_SET_H +#ifndef SOURCE_ENUM_SET_H_ +#define SOURCE_ENUM_SET_H_ #include #include @@ -21,9 +21,10 @@ #include #include -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" +#include "source/util/make_unique.h" -namespace libspirv { +namespace spvtools { // A set of values of a 32-bit enum type. // It is fast and compact for the common case, where enum values @@ -152,7 +153,7 @@ class EnumSet { // allocated if one doesn't exist yet. Returns overflow_set_. OverflowSetType& Overflow() { if (overflow_.get() == nullptr) { - overflow_.reset(new OverflowSetType); + overflow_ = MakeUnique(); } return *overflow_; } @@ -167,6 +168,6 @@ class EnumSet { // A set of SpvCapability, optimized for small capability values. using CapabilitySet = EnumSet; -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_ENUM_SET_H +#endif // SOURCE_ENUM_SET_H_ diff --git a/3rdparty/spirv-tools/source/enum_string_mapping.cpp b/3rdparty/spirv-tools/source/enum_string_mapping.cpp index e993b5842..32361a08d 100644 --- a/3rdparty/spirv-tools/source/enum_string_mapping.cpp +++ b/3rdparty/spirv-tools/source/enum_string_mapping.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "enum_string_mapping.h" +#include "source/enum_string_mapping.h" #include #include @@ -20,10 +20,10 @@ #include #include -#include "extensions.h" +#include "source/extensions.h" -namespace libspirv { +namespace spvtools { #include "enum_string_mapping.inc" -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/enum_string_mapping.h b/3rdparty/spirv-tools/source/enum_string_mapping.h index 4b126810a..af8f56b82 100644 --- a/3rdparty/spirv-tools/source/enum_string_mapping.h +++ b/3rdparty/spirv-tools/source/enum_string_mapping.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_ENUM_STRING_MAPPING_H_ -#define LIBSPIRV_ENUM_STRING_MAPPING_H_ +#ifndef SOURCE_ENUM_STRING_MAPPING_H_ +#define SOURCE_ENUM_STRING_MAPPING_H_ #include -#include "extensions.h" -#include "latest_version_spirv_header.h" +#include "source/extensions.h" +#include "source/latest_version_spirv_header.h" -namespace libspirv { +namespace spvtools { // Finds Extension enum corresponding to |str|. Returns false if not found. bool GetExtensionFromString(const char* str, Extension* extension); @@ -31,6 +31,6 @@ const char* ExtensionToString(Extension extension); // Returns text string corresponding to |capability|. const char* CapabilityToString(SpvCapability capability); -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_ENUM_STRING_MAPPING_H_ +#endif // SOURCE_ENUM_STRING_MAPPING_H_ diff --git a/3rdparty/spirv-tools/source/ext_inst.cpp b/3rdparty/spirv-tools/source/ext_inst.cpp index 6218eb115..a4c00c2ff 100644 --- a/3rdparty/spirv-tools/source/ext_inst.cpp +++ b/3rdparty/spirv-tools/source/ext_inst.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ext_inst.h" +#include "source/ext_inst.h" #include #include @@ -22,14 +22,14 @@ // TODO(dneto): DebugInfo.h should probably move to SPIRV-Headers. #include "DebugInfo.h" -#include "latest_version_glsl_std_450_header.h" -#include "latest_version_opencl_std_header.h" -#include "macro.h" -#include "spirv_definition.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "source/macro.h" +#include "source/spirv_definition.h" -#include "debuginfo.insts.inc" // defines opencl_entries -#include "glsl.std.450.insts.inc" // defines glsl_entries -#include "opencl.std.insts.inc" // defines opencl_entries +#include "debuginfo.insts.inc" +#include "glsl.std.450.insts.inc" +#include "opencl.std.insts.inc" #include "spv-amd-gcn-shader.insts.inc" #include "spv-amd-shader-ballot.insts.inc" @@ -81,6 +81,7 @@ spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable, case SPV_ENV_OPENGL_4_5: case SPV_ENV_UNIVERSAL_1_3: case SPV_ENV_VULKAN_1_1: + case SPV_ENV_WEBGPU_0: *pExtInstTable = &kTable_1_0; return SPV_SUCCESS; default: diff --git a/3rdparty/spirv-tools/source/ext_inst.h b/3rdparty/spirv-tools/source/ext_inst.h index 1a16b4b2d..a821cc2bc 100644 --- a/3rdparty/spirv-tools/source/ext_inst.h +++ b/3rdparty/spirv-tools/source/ext_inst.h @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_EXT_INST_H_ -#define LIBSPIRV_EXT_INST_H_ +#ifndef SOURCE_EXT_INST_H_ +#define SOURCE_EXT_INST_H_ +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" // Gets the type of the extended instruction set with the specified name. spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name); @@ -37,4 +37,4 @@ spv_result_t spvExtInstTableValueLookup(const spv_ext_inst_table table, const uint32_t value, spv_ext_inst_desc* pEntry); -#endif // LIBSPIRV_EXT_INST_H_ +#endif // SOURCE_EXT_INST_H_ diff --git a/3rdparty/spirv-tools/source/extensions.cpp b/3rdparty/spirv-tools/source/extensions.cpp index 065543c11..a94db273e 100644 --- a/3rdparty/spirv-tools/source/extensions.cpp +++ b/3rdparty/spirv-tools/source/extensions.cpp @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "extensions.h" +#include "source/extensions.h" #include #include #include -#include "enum_string_mapping.h" +#include "source/enum_string_mapping.h" -namespace libspirv { +namespace spvtools { std::string GetExtensionString(const spv_parsed_instruction_t* inst) { if (inst->opcode != SpvOpExtension) return "ERROR_not_op_extension"; @@ -41,4 +41,4 @@ std::string ExtensionSetToString(const ExtensionSet& extensions) { return ss.str(); } -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/extensions.h b/3rdparty/spirv-tools/source/extensions.h index 9947f1a9b..8023444c3 100644 --- a/3rdparty/spirv-tools/source/extensions.h +++ b/3rdparty/spirv-tools/source/extensions.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_EXTENSIONS_H_ -#define LIBSPIRV_EXTENSIONS_H_ +#ifndef SOURCE_EXTENSIONS_H_ +#define SOURCE_EXTENSIONS_H_ #include -#include "enum_set.h" +#include "source/enum_set.h" #include "spirv-tools/libspirv.h" -namespace libspirv { +namespace spvtools { // The known SPIR-V extensions. enum Extension { @@ -35,6 +35,6 @@ std::string GetExtensionString(const spv_parsed_instruction_t* inst); // Returns text string listing |extensions| separated by whitespace. std::string ExtensionSetToString(const ExtensionSet& extensions); -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_EXTENSIONS_H_ +#endif // SOURCE_EXTENSIONS_H_ diff --git a/3rdparty/spirv-tools/source/id_descriptor.cpp b/3rdparty/spirv-tools/source/id_descriptor.cpp index 7697c2993..d44ed672c 100644 --- a/3rdparty/spirv-tools/source/id_descriptor.cpp +++ b/3rdparty/spirv-tools/source/id_descriptor.cpp @@ -12,16 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "id_descriptor.h" +#include "source/id_descriptor.h" #include #include -#include "opcode.h" -#include "operand.h" - -namespace libspirv { +#include "source/opcode.h" +#include "source/operand.h" +namespace spvtools { namespace { // Hashes an array of words. Order of words is important. @@ -76,4 +75,4 @@ uint32_t IdDescriptorCollection::ProcessInstruction( return descriptor; } -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/id_descriptor.h b/3rdparty/spirv-tools/source/id_descriptor.h index d27123a1e..add23343a 100644 --- a/3rdparty/spirv-tools/source/id_descriptor.h +++ b/3rdparty/spirv-tools/source/id_descriptor.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_ID_DESCRIPTOR_H_ -#define LIBSPIRV_ID_DESCRIPTOR_H_ +#ifndef SOURCE_ID_DESCRIPTOR_H_ +#define SOURCE_ID_DESCRIPTOR_H_ #include #include #include "spirv-tools/libspirv.hpp" -namespace libspirv { +namespace spvtools { using CustomHashFunc = std::function&)>; @@ -58,6 +58,6 @@ class IdDescriptorCollection { std::vector words_; }; -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_ID_DESCRIPTOR_H_ +#endif // SOURCE_ID_DESCRIPTOR_H_ diff --git a/3rdparty/spirv-tools/source/instruction.h b/3rdparty/spirv-tools/source/instruction.h index 884276d7e..9e7dccd03 100644 --- a/3rdparty/spirv-tools/source/instruction.h +++ b/3rdparty/spirv-tools/source/instruction.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_INSTRUCTION_H_ -#define LIBSPIRV_INSTRUCTION_H_ +#ifndef SOURCE_INSTRUCTION_H_ +#define SOURCE_INSTRUCTION_H_ #include #include -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" #include "spirv-tools/libspirv.h" // Describes an instruction. @@ -46,4 +46,4 @@ inline void spvInstructionAddWord(spv_instruction_t* inst, uint32_t value) { inst->words.push_back(value); } -#endif // LIBSPIRV_INSTRUCTION_H_ +#endif // SOURCE_INSTRUCTION_H_ diff --git a/3rdparty/spirv-tools/source/latest_version_glsl_std_450_header.h b/3rdparty/spirv-tools/source/latest_version_glsl_std_450_header.h index b9e9ae275..bed1f2502 100644 --- a/3rdparty/spirv-tools/source/latest_version_glsl_std_450_header.h +++ b/3rdparty/spirv-tools/source/latest_version_glsl_std_450_header.h @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_LATEST_VERSION_GLSL_STD_450_HEADER_H_ -#define LIBSPIRV_LATEST_VERSION_GLSL_STD_450_HEADER_H_ +#ifndef SOURCE_LATEST_VERSION_GLSL_STD_450_HEADER_H_ +#define SOURCE_LATEST_VERSION_GLSL_STD_450_HEADER_H_ #include "spirv/unified1/GLSL.std.450.h" -#endif // LIBSPIRV_LATEST_VERSION_GLSL_STD_450_HEADER_H_ +#endif // SOURCE_LATEST_VERSION_GLSL_STD_450_HEADER_H_ diff --git a/3rdparty/spirv-tools/source/latest_version_opencl_std_header.h b/3rdparty/spirv-tools/source/latest_version_opencl_std_header.h index 9bb6e5028..90ff9c033 100644 --- a/3rdparty/spirv-tools/source/latest_version_opencl_std_header.h +++ b/3rdparty/spirv-tools/source/latest_version_opencl_std_header.h @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_LATEST_VERSION_OPENCL_STD_HEADER_H_ -#define LIBSPIRV_LATEST_VERSION_OPENCL_STD_HEADER_H_ +#ifndef SOURCE_LATEST_VERSION_OPENCL_STD_HEADER_H_ +#define SOURCE_LATEST_VERSION_OPENCL_STD_HEADER_H_ #include "spirv/unified1/OpenCL.std.h" -#endif // LIBSPIRV_LATEST_VERSION_OPENCL_STD_HEADER_H_ +#endif // SOURCE_LATEST_VERSION_OPENCL_STD_HEADER_H_ diff --git a/3rdparty/spirv-tools/source/latest_version_spirv_header.h b/3rdparty/spirv-tools/source/latest_version_spirv_header.h index c328b6976..e4f28e43e 100644 --- a/3rdparty/spirv-tools/source/latest_version_spirv_header.h +++ b/3rdparty/spirv-tools/source/latest_version_spirv_header.h @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_LATEST_VERSION_SPIRV_HEADER_H_ -#define LIBSPIRV_LATEST_VERSION_SPIRV_HEADER_H_ +#ifndef SOURCE_LATEST_VERSION_SPIRV_HEADER_H_ +#define SOURCE_LATEST_VERSION_SPIRV_HEADER_H_ #include "spirv/unified1/spirv.h" -#endif // LIBSPIRV_LATEST_VERSION_SPIRV_HEADER_H_ +#endif // SOURCE_LATEST_VERSION_SPIRV_HEADER_H_ diff --git a/3rdparty/spirv-tools/source/libspirv.cpp b/3rdparty/spirv-tools/source/libspirv.cpp index d0e3fe28c..cbbc4c908 100644 --- a/3rdparty/spirv-tools/source/libspirv.cpp +++ b/3rdparty/spirv-tools/source/libspirv.cpp @@ -14,7 +14,13 @@ #include "spirv-tools/libspirv.hpp" -#include "table.h" +#include + +#include +#include +#include + +#include "source/table.h" namespace spvtools { @@ -35,7 +41,7 @@ Context& Context::operator=(Context&& other) { Context::~Context() { spvContextDestroy(context_); } void Context::SetMessageConsumer(MessageConsumer consumer) { - libspirv::SetContextMessageConsumer(context_, std::move(consumer)); + SetContextMessageConsumer(context_, std::move(consumer)); } spv_context& Context::CContext() { return context_; } @@ -59,7 +65,7 @@ SpirvTools::SpirvTools(spv_target_env env) : impl_(new Impl(env)) {} SpirvTools::~SpirvTools() {} void SpirvTools::SetMessageConsumer(MessageConsumer consumer) { - libspirv::SetContextMessageConsumer(impl_->context, std::move(consumer)); + SetContextMessageConsumer(impl_->context, std::move(consumer)); } bool SpirvTools::Assemble(const std::string& text, @@ -109,10 +115,17 @@ bool SpirvTools::Validate(const uint32_t* binary, } bool SpirvTools::Validate(const uint32_t* binary, const size_t binary_size, - const spvtools::ValidatorOptions& options) const { + const ValidatorOptions& options) const { spv_const_binary_t the_binary{binary, binary_size}; - return spvValidateWithOptions(impl_->context, options, &the_binary, - nullptr) == SPV_SUCCESS; + spv_diagnostic diagnostic = nullptr; + bool valid = spvValidateWithOptions(impl_->context, options, &the_binary, + &diagnostic) == SPV_SUCCESS; + if (!valid && impl_->context->consumer) { + impl_->context->consumer.operator()( + SPV_MSG_ERROR, nullptr, diagnostic->position, diagnostic->error); + } + spvDiagnosticDestroy(diagnostic); + return valid; } } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/link/linker.cpp b/3rdparty/spirv-tools/source/link/linker.cpp index 49cda4bd3..f28b7595a 100644 --- a/3rdparty/spirv-tools/source/link/linker.cpp +++ b/3rdparty/spirv-tools/source/link/linker.cpp @@ -14,33 +14,36 @@ #include "spirv-tools/linker.hpp" +#include #include #include - -#include #include +#include +#include #include #include +#include #include -#include "assembly_grammar.h" -#include "diagnostic.h" -#include "opt/build_module.h" -#include "opt/compact_ids_pass.h" -#include "opt/decoration_manager.h" -#include "opt/ir_loader.h" -#include "opt/make_unique.h" -#include "opt/pass_manager.h" -#include "opt/remove_duplicates_pass.h" +#include "source/assembly_grammar.h" +#include "source/diagnostic.h" +#include "source/opt/build_module.h" +#include "source/opt/compact_ids_pass.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/ir_loader.h" +#include "source/opt/pass_manager.h" +#include "source/opt/remove_duplicates_pass.h" +#include "source/spirv_target_env.h" +#include "source/util/make_unique.h" #include "spirv-tools/libspirv.hpp" -#include "spirv_target_env.h" namespace spvtools { +namespace { -using ir::Instruction; -using ir::IRContext; -using ir::Module; -using ir::Operand; +using opt::IRContext; +using opt::Instruction; +using opt::Module; +using opt::Operand; using opt::PassManager; using opt::RemoveDuplicatesPass; using opt::analysis::DecorationManager; @@ -72,9 +75,9 @@ using LinkageTable = std::vector; // Both |modules| and |max_id_bound| should not be null, and |modules| should // not be empty either. Furthermore |modules| should not contain any null // pointers. -static spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, - std::vector* modules, - uint32_t* max_id_bound); +spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, + std::vector* modules, + uint32_t* max_id_bound); // Generates the header for the linked module and returns it in |header|. // @@ -84,19 +87,18 @@ static spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, // TODO(pierremoreau): What to do when binaries use different versions of // SPIR-V? For now, use the max of all versions found in // the input modules. -static spv_result_t GenerateHeader(const MessageConsumer& consumer, - const std::vector& modules, - uint32_t max_id_bound, - ir::ModuleHeader* header); +spv_result_t GenerateHeader(const MessageConsumer& consumer, + const std::vector& modules, + uint32_t max_id_bound, opt::ModuleHeader* header); // Merge all the modules from |in_modules| into a single module owned by // |linked_context|. // // |linked_context| should not be null. -static spv_result_t MergeModules(const MessageConsumer& consumer, - const std::vector& in_modules, - const libspirv::AssemblyGrammar& grammar, - IRContext* linked_context); +spv_result_t MergeModules(const MessageConsumer& consumer, + const std::vector& in_modules, + const AssemblyGrammar& grammar, + IRContext* linked_context); // Compute all pairs of import and export and return it in |linkings_to_do|. // @@ -107,20 +109,21 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, // applied to a single ID.) // TODO(pierremoreau): What should be the proper behaviour with built-in // symbols? -static spv_result_t GetImportExportPairs( - const MessageConsumer& consumer, const ir::IRContext& linked_context, - const DefUseManager& def_use_manager, - const DecorationManager& decoration_manager, bool allow_partial_linkage, - LinkageTable* linkings_to_do); +spv_result_t GetImportExportPairs(const MessageConsumer& consumer, + const opt::IRContext& linked_context, + const DefUseManager& def_use_manager, + const DecorationManager& decoration_manager, + bool allow_partial_linkage, + LinkageTable* linkings_to_do); // Checks that for each pair of import and export, the import and export have // the same type as well as the same decorations. // // TODO(pierremoreau): Decorations on functions parameters are currently not // checked. -static spv_result_t CheckImportExportCompatibility( - const MessageConsumer& consumer, const LinkageTable& linkings_to_do, - ir::IRContext* context); +spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, + const LinkageTable& linkings_to_do, + opt::IRContext* context); // Remove linkage specific instructions, such as prototypes of imported // functions, declarations of imported variables, import (and export if @@ -134,152 +137,29 @@ static spv_result_t CheckImportExportCompatibility( // applied to a single ID.) // TODO(pierremoreau): Run a pass for removing dead instructions, for example // OpName for prototypes of imported funcions. -static spv_result_t RemoveLinkageSpecificInstructions( +spv_result_t RemoveLinkageSpecificInstructions( const MessageConsumer& consumer, const LinkerOptions& options, const LinkageTable& linkings_to_do, DecorationManager* decoration_manager, - ir::IRContext* linked_context); + opt::IRContext* linked_context); // Verify that the unique ids of each instruction in |linked_context| (i.e. the // merged module) are truly unique. Does not check the validity of other ids -static spv_result_t VerifyIds(const MessageConsumer& consumer, - ir::IRContext* linked_context); +spv_result_t VerifyIds(const MessageConsumer& consumer, + opt::IRContext* linked_context); -spv_result_t Link(const Context& context, - const std::vector>& binaries, - std::vector* linked_binary, - const LinkerOptions& options) { - std::vector binary_ptrs; - binary_ptrs.reserve(binaries.size()); - std::vector binary_sizes; - binary_sizes.reserve(binaries.size()); - - for (const auto& binary : binaries) { - binary_ptrs.push_back(binary.data()); - binary_sizes.push_back(binary.size()); - } - - return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(), - linked_binary, options); -} - -spv_result_t Link(const Context& context, const uint32_t* const* binaries, - const size_t* binary_sizes, size_t num_binaries, - std::vector* linked_binary, - const LinkerOptions& options) { - spv_position_t position = {}; - const spv_context& c_context = context.CContext(); - const MessageConsumer& consumer = c_context->consumer; - - linked_binary->clear(); - if (num_binaries == 0u) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) - << "No modules were given."; - - std::vector> ir_contexts; - std::vector modules; - modules.reserve(num_binaries); - for (size_t i = 0u; i < num_binaries; ++i) { - const uint32_t schema = binaries[i][4u]; - if (schema != 0u) { - position.index = 4u; - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) - << "Schema is non-zero for module " << i << "."; - } - - std::unique_ptr ir_context = BuildModule( - c_context->target_env, consumer, binaries[i], binary_sizes[i]); - if (ir_context == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) - << "Failed to build a module out of " << ir_contexts.size() << "."; - modules.push_back(ir_context->module()); - ir_contexts.push_back(std::move(ir_context)); - } - - // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint - // range from the other binaries, and compute the new ID bound. - uint32_t max_id_bound = 0u; - spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound); - if (res != SPV_SUCCESS) return res; - - // Phase 2: Generate the header - ir::ModuleHeader header; - res = GenerateHeader(consumer, modules, max_id_bound, &header); - if (res != SPV_SUCCESS) return res; - IRContext linked_context(c_context->target_env, consumer); - linked_context.module()->SetHeader(header); - - // Phase 3: Merge all the binaries into a single one. - libspirv::AssemblyGrammar grammar(c_context); - res = MergeModules(consumer, modules, grammar, &linked_context); - if (res != SPV_SUCCESS) return res; - - if (options.GetVerifyIds()) { - res = VerifyIds(consumer, &linked_context); - if (res != SPV_SUCCESS) return res; - } - - // Phase 4: Find the import/export pairs - LinkageTable linkings_to_do; - res = GetImportExportPairs(consumer, linked_context, - *linked_context.get_def_use_mgr(), - *linked_context.get_decoration_mgr(), - options.GetAllowPartialLinkage(), &linkings_to_do); - if (res != SPV_SUCCESS) return res; - - // Phase 5: Ensure the import and export have the same types and decorations. - res = - CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context); - if (res != SPV_SUCCESS) return res; - - // Phase 6: Remove duplicates - PassManager manager; - manager.SetMessageConsumer(consumer); - manager.AddPass(); - opt::Pass::Status pass_res = manager.Run(&linked_context); - if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; - - // Phase 7: Rematch import variables/functions to export variables/functions - for (const auto& linking_entry : linkings_to_do) - linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id, - linking_entry.exported_symbol.id); - - // Phase 8: Remove linkage specific instructions, such as import/export - // attributes, linkage capability, etc. if applicable - res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do, - linked_context.get_decoration_mgr(), - &linked_context); - if (res != SPV_SUCCESS) return res; - - // Phase 9: Compact the IDs used in the module - manager.AddPass(); - pass_res = manager.Run(&linked_context); - if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; - - // Phase 10: Output the module - linked_context.module()->ToBinary(linked_binary, true); - - return SPV_SUCCESS; -} - -static spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, - std::vector* modules, - uint32_t* max_id_bound) { +spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, + std::vector* modules, + uint32_t* max_id_bound) { spv_position_t position = {}; if (modules == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|modules| of ShiftIdsInModules should not be null."; if (modules->empty()) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|modules| of ShiftIdsInModules should not be empty."; if (max_id_bound == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|max_id_bound| of ShiftIdsInModules should not be null."; uint32_t id_bound = modules->front()->IdBound() - 1u; @@ -291,17 +171,16 @@ static spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, }); id_bound += module->IdBound() - 1u; if (id_bound > 0x3FFFFF) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_ID) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID) << "The limit of IDs, 4194303, was exceeded:" << " " << id_bound << " is the current ID bound."; // Invalidate the DefUseManager - module->context()->InvalidateAnalyses(ir::IRContext::kAnalysisDefUse); + module->context()->InvalidateAnalyses(opt::IRContext::kAnalysisDefUse); } ++id_bound; if (id_bound > 0x3FFFFF) - return libspirv::DiagnosticStream(position, consumer, SPV_ERROR_INVALID_ID) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID) << "The limit of IDs, 4194303, was exceeded:" << " " << id_bound << " is the current ID bound."; @@ -310,19 +189,16 @@ static spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, return SPV_SUCCESS; } -static spv_result_t GenerateHeader(const MessageConsumer& consumer, - const std::vector& modules, - uint32_t max_id_bound, - ir::ModuleHeader* header) { +spv_result_t GenerateHeader(const MessageConsumer& consumer, + const std::vector& modules, + uint32_t max_id_bound, opt::ModuleHeader* header) { spv_position_t position = {}; if (modules.empty()) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|modules| of GenerateHeader should not be empty."; if (max_id_bound == 0u) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|max_id_bound| of GenerateHeader should not be null."; uint32_t version = 0u; @@ -338,15 +214,14 @@ static spv_result_t GenerateHeader(const MessageConsumer& consumer, return SPV_SUCCESS; } -static spv_result_t MergeModules(const MessageConsumer& consumer, - const std::vector& input_modules, - const libspirv::AssemblyGrammar& grammar, - IRContext* linked_context) { +spv_result_t MergeModules(const MessageConsumer& consumer, + const std::vector& input_modules, + const AssemblyGrammar& grammar, + IRContext* linked_context) { spv_position_t position = {}; if (linked_context == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|linked_module| of MergeModules should not be null."; Module* linked_module = linked_context->module(); @@ -384,8 +259,7 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL, memory_model_inst->GetSingleWordOperand(0u), ¤t_desc); - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INTERNAL) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) << "Conflicting addressing models: " << initial_desc->name << " vs " << current_desc->name << "."; } @@ -396,8 +270,7 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, memory_model_inst->GetSingleWordOperand(1u), ¤t_desc); - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INTERNAL) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) << "Conflicting memory models: " << initial_desc->name << " vs " << current_desc->name << "."; } @@ -422,8 +295,7 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, if (i != entry_points.end()) { spv_operand_desc desc = nullptr; grammar.lookupOperand(SPV_OPERAND_TYPE_EXECUTION_MODEL, model, &desc); - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INTERNAL) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) << "The entry point \"" << name << "\", with execution model " << desc->name << ", was already defined."; } @@ -456,10 +328,11 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, // OpModuleProcessed instruction about the linking step. if (linked_module->version() >= 0x10100) { const std::string processed_string("Linked by SPIR-V Tools Linker"); - const size_t words_nb = - processed_string.size() / 4u + (processed_string.size() % 4u != 0u); - std::vector processed_words(words_nb, 0u); - std::memcpy(processed_words.data(), processed_string.data(), words_nb * 4u); + const auto num_chars = processed_string.size(); + // Compute num words, accommodate the terminating null character. + const auto num_words = (num_chars + 1 + 3) / 4; + std::vector processed_words(num_words, 0u); + std::memcpy(processed_words.data(), processed_string.data(), num_chars); linked_module->AddDebug3Inst(std::unique_ptr( new Instruction(linked_context, SpvOpModuleProcessed, 0u, 0u, {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}}))); @@ -482,15 +355,14 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, } } if (num_global_values > 0xFFFF) - return libspirv::DiagnosticStream(position, consumer, SPV_ERROR_INTERNAL) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) << "The limit of global values, 65535, was exceeded;" << " " << num_global_values << " global values were found."; // Process functions and their basic blocks for (const auto& module : input_modules) { for (const auto& func : *module) { - std::unique_ptr cloned_func(func.Clone(linked_context)); - cloned_func->SetParent(linked_module); + std::unique_ptr cloned_func(func.Clone(linked_context)); linked_module->AddFunction(std::move(cloned_func)); } } @@ -498,16 +370,16 @@ static spv_result_t MergeModules(const MessageConsumer& consumer, return SPV_SUCCESS; } -static spv_result_t GetImportExportPairs( - const MessageConsumer& consumer, const ir::IRContext& linked_context, - const DefUseManager& def_use_manager, - const DecorationManager& decoration_manager, bool allow_partial_linkage, - LinkageTable* linkings_to_do) { +spv_result_t GetImportExportPairs(const MessageConsumer& consumer, + const opt::IRContext& linked_context, + const DefUseManager& def_use_manager, + const DecorationManager& decoration_manager, + bool allow_partial_linkage, + LinkageTable* linkings_to_do) { spv_position_t position = {}; if (linkings_to_do == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|linkings_to_do| of GetImportExportPairs should not be empty."; std::vector imports; @@ -546,8 +418,7 @@ static spv_result_t GetImportExportPairs( // types. const Instruction* def_inst = def_use_manager.GetDef(id); if (def_inst == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "ID " << id << " is never defined:\n"; if (def_inst->opcode() == SpvOpVariable) { @@ -565,8 +436,7 @@ static spv_result_t GetImportExportPairs( }); } } else { - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "Only global variables and functions can be decorated using" << " LinkageAttributes; " << id << " is neither of them.\n"; } @@ -583,12 +453,10 @@ static spv_result_t GetImportExportPairs( const auto& exp = exports.find(import.name); if (exp != exports.end()) possible_exports = exp->second; if (possible_exports.empty() && !allow_partial_linkage) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "Unresolved external reference to \"" << import.name << "\"."; else if (possible_exports.size() > 1u) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "Too many external references, " << possible_exports.size() << ", were found for \"" << import.name << "\"."; @@ -599,9 +467,9 @@ static spv_result_t GetImportExportPairs( return SPV_SUCCESS; } -static spv_result_t CheckImportExportCompatibility( - const MessageConsumer& consumer, const LinkageTable& linkings_to_do, - ir::IRContext* context) { +spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, + const LinkageTable& linkings_to_do, + opt::IRContext* context) { spv_position_t position = {}; // Ensure th import and export types are the same. @@ -612,8 +480,7 @@ static spv_result_t CheckImportExportCompatibility( *def_use_manager.GetDef(linking_entry.imported_symbol.type_id), *def_use_manager.GetDef(linking_entry.exported_symbol.type_id), context)) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "Type mismatch on symbol \"" << linking_entry.imported_symbol.name << "\" between imported variable/function %" @@ -626,8 +493,7 @@ static spv_result_t CheckImportExportCompatibility( for (const auto& linking_entry : linkings_to_do) { if (!decoration_manager.HaveTheSameDecorations( linking_entry.imported_symbol.id, linking_entry.exported_symbol.id)) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_BINARY) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "Decorations mismatch on symbol \"" << linking_entry.imported_symbol.name << "\" between imported variable/function %" @@ -644,20 +510,18 @@ static spv_result_t CheckImportExportCompatibility( return SPV_SUCCESS; } -static spv_result_t RemoveLinkageSpecificInstructions( +spv_result_t RemoveLinkageSpecificInstructions( const MessageConsumer& consumer, const LinkerOptions& options, const LinkageTable& linkings_to_do, DecorationManager* decoration_manager, - ir::IRContext* linked_context) { + opt::IRContext* linked_context) { spv_position_t position = {}; if (decoration_manager == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|decoration_manager| of RemoveLinkageSpecificInstructions " "should not be empty."; if (linked_context == nullptr) - return libspirv::DiagnosticStream(position, consumer, - SPV_ERROR_INVALID_DATA) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) << "|linked_module| of RemoveLinkageSpecificInstructions should not " "be empty."; @@ -767,11 +631,11 @@ static spv_result_t RemoveLinkageSpecificInstructions( } spv_result_t VerifyIds(const MessageConsumer& consumer, - ir::IRContext* linked_context) { + opt::IRContext* linked_context) { std::unordered_set ids; bool ok = true; linked_context->module()->ForEachInst( - [&ids, &ok](const ir::Instruction* inst) { + [&ids, &ok](const opt::Instruction* inst) { ok &= ids.insert(inst->unique_id()).second; }); @@ -783,4 +647,123 @@ spv_result_t VerifyIds(const MessageConsumer& consumer, return SPV_SUCCESS; } +} // namespace + +spv_result_t Link(const Context& context, + const std::vector>& binaries, + std::vector* linked_binary, + const LinkerOptions& options) { + std::vector binary_ptrs; + binary_ptrs.reserve(binaries.size()); + std::vector binary_sizes; + binary_sizes.reserve(binaries.size()); + + for (const auto& binary : binaries) { + binary_ptrs.push_back(binary.data()); + binary_sizes.push_back(binary.size()); + } + + return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(), + linked_binary, options); +} + +spv_result_t Link(const Context& context, const uint32_t* const* binaries, + const size_t* binary_sizes, size_t num_binaries, + std::vector* linked_binary, + const LinkerOptions& options) { + spv_position_t position = {}; + const spv_context& c_context = context.CContext(); + const MessageConsumer& consumer = c_context->consumer; + + linked_binary->clear(); + if (num_binaries == 0u) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "No modules were given."; + + std::vector> ir_contexts; + std::vector modules; + modules.reserve(num_binaries); + for (size_t i = 0u; i < num_binaries; ++i) { + const uint32_t schema = binaries[i][4u]; + if (schema != 0u) { + position.index = 4u; + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Schema is non-zero for module " << i << "."; + } + + std::unique_ptr ir_context = BuildModule( + c_context->target_env, consumer, binaries[i], binary_sizes[i]); + if (ir_context == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Failed to build a module out of " << ir_contexts.size() << "."; + modules.push_back(ir_context->module()); + ir_contexts.push_back(std::move(ir_context)); + } + + // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint + // range from the other binaries, and compute the new ID bound. + uint32_t max_id_bound = 0u; + spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound); + if (res != SPV_SUCCESS) return res; + + // Phase 2: Generate the header + opt::ModuleHeader header; + res = GenerateHeader(consumer, modules, max_id_bound, &header); + if (res != SPV_SUCCESS) return res; + IRContext linked_context(c_context->target_env, consumer); + linked_context.module()->SetHeader(header); + + // Phase 3: Merge all the binaries into a single one. + AssemblyGrammar grammar(c_context); + res = MergeModules(consumer, modules, grammar, &linked_context); + if (res != SPV_SUCCESS) return res; + + if (options.GetVerifyIds()) { + res = VerifyIds(consumer, &linked_context); + if (res != SPV_SUCCESS) return res; + } + + // Phase 4: Find the import/export pairs + LinkageTable linkings_to_do; + res = GetImportExportPairs(consumer, linked_context, + *linked_context.get_def_use_mgr(), + *linked_context.get_decoration_mgr(), + options.GetAllowPartialLinkage(), &linkings_to_do); + if (res != SPV_SUCCESS) return res; + + // Phase 5: Ensure the import and export have the same types and decorations. + res = + CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context); + if (res != SPV_SUCCESS) return res; + + // Phase 6: Remove duplicates + PassManager manager; + manager.SetMessageConsumer(consumer); + manager.AddPass(); + opt::Pass::Status pass_res = manager.Run(&linked_context); + if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; + + // Phase 7: Rematch import variables/functions to export variables/functions + for (const auto& linking_entry : linkings_to_do) + linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id, + linking_entry.exported_symbol.id); + + // Phase 8: Remove linkage specific instructions, such as import/export + // attributes, linkage capability, etc. if applicable + res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do, + linked_context.get_decoration_mgr(), + &linked_context); + if (res != SPV_SUCCESS) return res; + + // Phase 9: Compact the IDs used in the module + manager.AddPass(); + pass_res = manager.Run(&linked_context); + if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; + + // Phase 10: Output the module + linked_context.module()->ToBinary(linked_binary, true); + + return SPV_SUCCESS; +} + } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/macro.h b/3rdparty/spirv-tools/source/macro.h index 810db6e61..7219ffed1 100644 --- a/3rdparty/spirv-tools/source/macro.h +++ b/3rdparty/spirv-tools/source/macro.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_MACRO_H_ -#define LIBSPIRV_MACRO_H_ +#ifndef SOURCE_MACRO_H_ +#define SOURCE_MACRO_H_ // Evaluates to the number of elements of array A. // @@ -22,4 +22,4 @@ // std::array::size. #define ARRAY_SIZE(A) (static_cast(sizeof(A) / sizeof(A[0]))) -#endif // LIBSPIRV_MACRO_H_ +#endif // SOURCE_MACRO_H_ diff --git a/3rdparty/spirv-tools/source/name_mapper.cpp b/3rdparty/spirv-tools/source/name_mapper.cpp index a8b5a7c90..43fdfb34b 100644 --- a/3rdparty/spirv-tools/source/name_mapper.cpp +++ b/3rdparty/spirv-tools/source/name_mapper.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "name_mapper.h" +#include "source/name_mapper.h" #include #include @@ -24,9 +24,10 @@ #include "spirv-tools/libspirv.h" -#include "latest_version_spirv_header.h" -#include "parsed_operand.h" +#include "source/latest_version_spirv_header.h" +#include "source/parsed_operand.h" +namespace spvtools { namespace { // Converts a uint32_t to its string decimal representation. @@ -40,14 +41,12 @@ std::string to_string(uint32_t id) { } // anonymous namespace -namespace libspirv { - NameMapper GetTrivialNameMapper() { return to_string; } FriendlyNameMapper::FriendlyNameMapper(const spv_const_context context, const uint32_t* code, const size_t wordCount) - : grammar_(libspirv::AssemblyGrammar(context)) { + : grammar_(AssemblyGrammar(context)) { spv_diagnostic diag = nullptr; // We don't care if the parse fails. spvBinaryParse(context, this, code, wordCount, nullptr, @@ -329,4 +328,4 @@ std::string FriendlyNameMapper::NameForEnumOperand(spv_operand_type_t type, } } -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/name_mapper.h b/3rdparty/spirv-tools/source/name_mapper.h index 8afac4241..6902141b1 100644 --- a/3rdparty/spirv-tools/source/name_mapper.h +++ b/3rdparty/spirv-tools/source/name_mapper.h @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_NAME_MAPPER_H_ -#define LIBSPIRV_NAME_MAPPER_H_ +#ifndef SOURCE_NAME_MAPPER_H_ +#define SOURCE_NAME_MAPPER_H_ #include #include #include #include -#include "assembly_grammar.h" +#include "source/assembly_grammar.h" #include "spirv-tools/libspirv.h" -namespace libspirv { +namespace spvtools { // A NameMapper maps SPIR-V Id values to names. Each name is valid to use in // SPIR-V assembly. The mapping is one-to-one, i.e. no two Ids map to the same @@ -114,9 +114,9 @@ class FriendlyNameMapper { // The set of names that have a mapping in name_for_id_; std::unordered_set used_names_; // The assembly grammar for the current context. - const libspirv::AssemblyGrammar grammar_; + const AssemblyGrammar grammar_; }; -} // namespace libspirv +} // namespace spvtools -#endif // _LIBSPIRV_NAME_MAPPER_H_ +#endif // SOURCE_NAME_MAPPER_H_ diff --git a/3rdparty/spirv-tools/source/opcode.cpp b/3rdparty/spirv-tools/source/opcode.cpp index c73f14d3a..af34b6460 100644 --- a/3rdparty/spirv-tools/source/opcode.cpp +++ b/3rdparty/spirv-tools/source/opcode.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opcode.h" +#include "source/opcode.h" #include #include @@ -20,12 +20,12 @@ #include #include -#include "instruction.h" -#include "macro.h" +#include "source/instruction.h" +#include "source/macro.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" +#include "source/spirv_target_env.h" #include "spirv-tools/libspirv.h" -#include "spirv_constant.h" -#include "spirv_endian.h" -#include "spirv_target_env.h" namespace { struct OpcodeDescPtrLen { @@ -33,7 +33,7 @@ struct OpcodeDescPtrLen { uint32_t len; }; -#include "core.insts-unified1.inc" // defines kOpcodeTableEntries_1_3 +#include "core.insts-unified1.inc" static const spv_opcode_table_t kOpcodeTable = {ARRAY_SIZE(kOpcodeTableEntries), kOpcodeTableEntries}; @@ -108,7 +108,7 @@ spv_result_t spvOpcodeTableNameLookup(spv_target_env env, // is indeed requested in the SPIR-V code; checking that should be // validator's work. if ((spvVersionForTargetEnv(env) >= entry.minVersion || - entry.numExtensions > 0u) && + entry.numExtensions > 0u || entry.numCapabilities > 0u) && nameLength == strlen(entry.name) && !strncmp(name, entry.name, nameLength)) { // NOTE: Found out Opcode! @@ -153,7 +153,7 @@ spv_result_t spvOpcodeTableValueLookup(spv_target_env env, // is indeed requested in the SPIR-V code; checking that should be // validator's work. if (spvVersionForTargetEnv(env) >= it->minVersion || - it->numExtensions > 0u) { + it->numExtensions > 0u || it->numCapabilities > 0u) { *pEntry = it; return SPV_SUCCESS; } @@ -454,3 +454,132 @@ bool spvOpcodeIsBaseOpaqueType(SpvOp opcode) { return false; } } + +bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode) { + switch (opcode) { + case SpvOpGroupNonUniformElect: + case SpvOpGroupNonUniformAll: + case SpvOpGroupNonUniformAny: + case SpvOpGroupNonUniformAllEqual: + case SpvOpGroupNonUniformBroadcast: + case SpvOpGroupNonUniformBroadcastFirst: + case SpvOpGroupNonUniformBallot: + case SpvOpGroupNonUniformInverseBallot: + case SpvOpGroupNonUniformBallotBitExtract: + case SpvOpGroupNonUniformBallotBitCount: + case SpvOpGroupNonUniformBallotFindLSB: + case SpvOpGroupNonUniformBallotFindMSB: + case SpvOpGroupNonUniformShuffle: + case SpvOpGroupNonUniformShuffleXor: + case SpvOpGroupNonUniformShuffleUp: + case SpvOpGroupNonUniformShuffleDown: + case SpvOpGroupNonUniformIAdd: + case SpvOpGroupNonUniformFAdd: + case SpvOpGroupNonUniformIMul: + case SpvOpGroupNonUniformFMul: + case SpvOpGroupNonUniformSMin: + case SpvOpGroupNonUniformUMin: + case SpvOpGroupNonUniformFMin: + case SpvOpGroupNonUniformSMax: + case SpvOpGroupNonUniformUMax: + case SpvOpGroupNonUniformFMax: + case SpvOpGroupNonUniformBitwiseAnd: + case SpvOpGroupNonUniformBitwiseOr: + case SpvOpGroupNonUniformBitwiseXor: + case SpvOpGroupNonUniformLogicalAnd: + case SpvOpGroupNonUniformLogicalOr: + case SpvOpGroupNonUniformLogicalXor: + case SpvOpGroupNonUniformQuadBroadcast: + case SpvOpGroupNonUniformQuadSwap: + return true; + default: + return false; + } +} + +bool spvOpcodeIsScalarizable(SpvOp opcode) { + switch (opcode) { + case SpvOpPhi: + case SpvOpCopyObject: + case SpvOpConvertFToU: + case SpvOpConvertFToS: + case SpvOpConvertSToF: + case SpvOpConvertUToF: + case SpvOpUConvert: + case SpvOpSConvert: + case SpvOpFConvert: + case SpvOpQuantizeToF16: + case SpvOpVectorInsertDynamic: + case SpvOpSNegate: + case SpvOpFNegate: + case SpvOpIAdd: + case SpvOpFAdd: + case SpvOpISub: + case SpvOpFSub: + case SpvOpIMul: + case SpvOpFMul: + case SpvOpUDiv: + case SpvOpSDiv: + case SpvOpFDiv: + case SpvOpUMod: + case SpvOpSRem: + case SpvOpSMod: + case SpvOpFRem: + case SpvOpFMod: + case SpvOpVectorTimesScalar: + case SpvOpIAddCarry: + case SpvOpISubBorrow: + case SpvOpUMulExtended: + case SpvOpSMulExtended: + case SpvOpShiftRightLogical: + case SpvOpShiftRightArithmetic: + case SpvOpShiftLeftLogical: + case SpvOpBitwiseOr: + case SpvOpBitwiseAnd: + case SpvOpNot: + case SpvOpBitFieldInsert: + case SpvOpBitFieldSExtract: + case SpvOpBitFieldUExtract: + case SpvOpBitReverse: + case SpvOpBitCount: + case SpvOpIsNan: + case SpvOpIsInf: + case SpvOpIsFinite: + case SpvOpIsNormal: + case SpvOpSignBitSet: + case SpvOpLessOrGreater: + case SpvOpOrdered: + case SpvOpUnordered: + case SpvOpLogicalEqual: + case SpvOpLogicalNotEqual: + case SpvOpLogicalOr: + case SpvOpLogicalAnd: + case SpvOpLogicalNot: + case SpvOpSelect: + case SpvOpIEqual: + case SpvOpINotEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpFOrdEqual: + case SpvOpFUnordEqual: + case SpvOpFOrdNotEqual: + case SpvOpFUnordNotEqual: + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: + return true; + default: + return false; + } +} diff --git a/3rdparty/spirv-tools/source/opcode.h b/3rdparty/spirv-tools/source/opcode.h index 9b585137e..5643a64c8 100644 --- a/3rdparty/spirv-tools/source/opcode.h +++ b/3rdparty/spirv-tools/source/opcode.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPCODE_H_ -#define LIBSPIRV_OPCODE_H_ +#ifndef SOURCE_OPCODE_H_ +#define SOURCE_OPCODE_H_ -#include "instruction.h" -#include "latest_version_spirv_header.h" +#include "source/instruction.h" +#include "source/latest_version_spirv_header.h" +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" // Returns the name of a registered SPIR-V generator as a null-terminated // string. If the generator is not known, then returns the string "Unknown". @@ -118,4 +118,11 @@ bool spvOpcodeIsBlockTerminator(SpvOp opcode); // Returns true if the given opcode always defines an opaque type. bool spvOpcodeIsBaseOpaqueType(SpvOp opcode); -#endif // LIBSPIRV_OPCODE_H_ + +// Returns true if the given opcode is a non-uniform group operation. +bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode); + +// Returns true if the opcode with vector inputs could be divided into a series +// of independent scalar operations that would give the same result. +bool spvOpcodeIsScalarizable(SpvOp opcode); +#endif // SOURCE_OPCODE_H_ diff --git a/3rdparty/spirv-tools/source/operand.cpp b/3rdparty/spirv-tools/source/operand.cpp index 4a2a2d6c0..c97b13fc6 100644 --- a/3rdparty/spirv-tools/source/operand.cpp +++ b/3rdparty/spirv-tools/source/operand.cpp @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "operand.h" +#include "source/operand.h" #include #include #include -#include "macro.h" -#include "spirv_constant.h" -#include "spirv_target_env.h" +#include "source/macro.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" // For now, assume unified1 contains up to SPIR-V 1.3 and no later // SPIR-V version. @@ -55,16 +55,17 @@ spv_result_t spvOperandTableNameLookup(spv_target_env env, if (type != group.type) continue; for (uint64_t index = 0; index < group.count; ++index) { const auto& entry = group.entries[index]; - // We considers the current operand as available as long as + // We consider the current operand as available as long as // 1. The target environment satisfies the minimal requirement of the // operand; or - // 2. There is at least one extension enabling this operand. + // 2. There is at least one extension enabling this operand; or + // 3. There is at least one capability enabling this operand. // // Note that the second rule assumes the extension enabling this operand // is indeed requested in the SPIR-V code; checking that should be // validator's work. if ((spvVersionForTargetEnv(env) >= entry.minVersion || - entry.numExtensions > 0u) && + entry.numExtensions > 0u || entry.numCapabilities > 0u) && nameLength == strlen(entry.name) && !strncmp(entry.name, name, nameLength)) { *pEntry = &entry; @@ -109,16 +110,17 @@ spv_result_t spvOperandTableValueLookup(spv_target_env env, // opcode value. for (auto it = std::lower_bound(beg, end, needle, comp); it != end && it->value == value; ++it) { - // We considers the current operand as available as long as + // We consider the current operand as available as long as // 1. The target environment satisfies the minimal requirement of the // operand; or - // 2. There is at least one extension enabling this operand. + // 2. There is at least one extension enabling this operand; or + // 3. There is at least one capability enabling this operand. // // Note that the second rule assumes the extension enabling this operand // is indeed requested in the SPIR-V code; checking that should be // validator's work. if (spvVersionForTargetEnv(env) >= it->minVersion || - it->numExtensions > 0u) { + it->numExtensions > 0u || it->numCapabilities > 0u) { *pEntry = it; return SPV_SUCCESS; } @@ -244,8 +246,9 @@ const char* spvOperandTypeStr(spv_operand_type_t type) { void spvPushOperandTypes(const spv_operand_type_t* types, spv_operand_pattern_t* pattern) { const spv_operand_type_t* endTypes; - for (endTypes = types; *endTypes != SPV_OPERAND_TYPE_NONE; ++endTypes) - ; + for (endTypes = types; *endTypes != SPV_OPERAND_TYPE_NONE; ++endTypes) { + } + while (endTypes-- != types) { pattern->push_back(*endTypes); } @@ -413,6 +416,7 @@ std::function spvOperandCanBeForwardDeclaredFunction( std::function out; switch (opcode) { case SpvOpExecutionMode: + case SpvOpExecutionModeId: case SpvOpEntryPoint: case SpvOpName: case SpvOpMemberName: diff --git a/3rdparty/spirv-tools/source/operand.h b/3rdparty/spirv-tools/source/operand.h index 984a62328..76f16f7ae 100644 --- a/3rdparty/spirv-tools/source/operand.h +++ b/3rdparty/spirv-tools/source/operand.h @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPERAND_H_ -#define LIBSPIRV_OPERAND_H_ +#ifndef SOURCE_OPERAND_H_ +#define SOURCE_OPERAND_H_ -#include #include +#include +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" // A sequence of operand types. // @@ -138,4 +138,4 @@ bool spvIsIdType(spv_operand_type_t type); std::function spvOperandCanBeForwardDeclaredFunction( SpvOp opcode); -#endif // LIBSPIRV_OPERAND_H_ +#endif // SOURCE_OPERAND_H_ diff --git a/3rdparty/spirv-tools/source/opt/CMakeLists.txt b/3rdparty/spirv-tools/source/opt/CMakeLists.txt index 60028096a..83f92fe88 100644 --- a/3rdparty/spirv-tools/source/opt/CMakeLists.txt +++ b/3rdparty/spirv-tools/source/opt/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(SPIRV-Tools-opt ccp_pass.h cfg_cleanup_pass.h cfg.h + combine_access_chains.h common_uniform_elim_pass.h compact_ids_pass.h composite.h @@ -45,7 +46,6 @@ add_library(SPIRV-Tools-opt inline_exhaustive_pass.h inline_opaque_pass.h inline_pass.h - insert_extract_elim.h instruction.h instruction_list.h ir_builder.h @@ -58,12 +58,15 @@ add_library(SPIRV-Tools-opt local_single_store_elim_pass.h local_ssa_elim_pass.h log.h + loop_dependence.h loop_descriptor.h + loop_fission.h + loop_fusion.h + loop_fusion_pass.h loop_peeling.h loop_unroller.h loop_utils.h loop_unswitch_pass.h - make_unique.h mem_pass.h merge_return_pass.h module.h @@ -73,8 +76,10 @@ add_library(SPIRV-Tools-opt pass_manager.h private_to_local_pass.h propagator.h + reduce_load_size.h redundancy_elimination.h reflect.h + register_pressure.h remove_duplicates_pass.h replace_invalid_opc.h scalar_analysis.h @@ -91,6 +96,7 @@ add_library(SPIRV-Tools-opt types.h unify_const_pass.h value_number_table.h + vector_dce.h workaround1209.h aggressive_dead_code_elim_pass.cpp @@ -100,6 +106,7 @@ add_library(SPIRV-Tools-opt ccp_pass.cpp cfg_cleanup_pass.cpp cfg.cpp + combine_access_chains.cpp common_uniform_elim_pass.cpp compact_ids_pass.cpp composite.cpp @@ -126,7 +133,6 @@ add_library(SPIRV-Tools-opt inline_exhaustive_pass.cpp inline_opaque_pass.cpp inline_pass.cpp - insert_extract_elim.cpp instruction.cpp instruction_list.cpp ir_context.cpp @@ -137,7 +143,12 @@ add_library(SPIRV-Tools-opt local_single_block_elim_pass.cpp local_single_store_elim_pass.cpp local_ssa_elim_pass.cpp + loop_dependence.cpp + loop_dependence_helpers.cpp loop_descriptor.cpp + loop_fission.cpp + loop_fusion.cpp + loop_fusion_pass.cpp loop_peeling.cpp loop_utils.cpp loop_unroller.cpp @@ -150,7 +161,9 @@ add_library(SPIRV-Tools-opt pass_manager.cpp private_to_local_pass.cpp propagator.cpp + reduce_load_size.cpp redundancy_elimination.cpp + register_pressure.cpp remove_duplicates_pass.cpp replace_invalid_opc.cpp scalar_analysis.cpp @@ -166,6 +179,7 @@ add_library(SPIRV-Tools-opt types.cpp unify_const_pass.cpp value_number_table.cpp + vector_dce.cpp workaround1209.cpp ) diff --git a/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp index 014ce0062..faf278aa6 100644 --- a/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp @@ -1,6 +1,7 @@ // Copyright (c) 2017 The Khronos Group Inc. // Copyright (c) 2017 Valve Corporation // Copyright (c) 2017 LunarG Inc. +// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,15 +15,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "aggressive_dead_code_elim_pass.h" - -#include "cfa.h" -#include "iterator.h" -#include "latest_version_glsl_std_450_header.h" -#include "reflect.h" +#include "source/opt/aggressive_dead_code_elim_pass.h" +#include #include +#include "source/cfa.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/iterator.h" +#include "source/opt/reflect.h" + namespace spvtools { namespace opt { @@ -33,6 +35,8 @@ const uint32_t kEntryPointFunctionIdInIdx = 1; const uint32_t kSelectionMergeMergeBlockIdInIdx = 0; const uint32_t kLoopMergeMergeBlockIdInIdx = 0; const uint32_t kLoopMergeContinueBlockIdInIdx = 1; +const uint32_t kCopyMemoryTargetAddrInIdx = 0; +const uint32_t kCopyMemorySourceAddrInIdx = 1; // Sorting functor to present annotation instructions in an easy-to-process // order. The functor orders by opcode first and falls back on unique id @@ -44,40 +48,29 @@ const uint32_t kLoopMergeContinueBlockIdInIdx = 1; // SpvOpDecorate // SpvOpMemberDecorate // SpvOpDecorateId +// SpvOpDecorateStringGOOGLE // SpvOpDecorationGroup struct DecorationLess { - bool operator()(const ir::Instruction* lhs, - const ir::Instruction* rhs) const { + bool operator()(const Instruction* lhs, const Instruction* rhs) const { assert(lhs && rhs); SpvOp lhsOp = lhs->opcode(); SpvOp rhsOp = rhs->opcode(); if (lhsOp != rhsOp) { +#define PRIORITY_CASE(opcode) \ + if (lhsOp == opcode && rhsOp != opcode) return true; \ + if (rhsOp == opcode && lhsOp != opcode) return false; // OpGroupDecorate and OpGroupMember decorate are highest priority to // eliminate dead targets early and simplify subsequent checks. - if (lhsOp == SpvOpGroupDecorate && rhsOp != SpvOpGroupDecorate) - return true; - if (rhsOp == SpvOpGroupDecorate && lhsOp != SpvOpGroupDecorate) - return false; - if (lhsOp == SpvOpGroupMemberDecorate && - rhsOp != SpvOpGroupMemberDecorate) - return true; - if (rhsOp == SpvOpGroupMemberDecorate && - lhsOp != SpvOpGroupMemberDecorate) - return false; - if (lhsOp == SpvOpDecorate && rhsOp != SpvOpDecorate) return true; - if (rhsOp == SpvOpDecorate && lhsOp != SpvOpDecorate) return false; - if (lhsOp == SpvOpMemberDecorate && rhsOp != SpvOpMemberDecorate) - return true; - if (rhsOp == SpvOpMemberDecorate && lhsOp != SpvOpMemberDecorate) - return false; - if (lhsOp == SpvOpDecorateId && rhsOp != SpvOpDecorateId) return true; - if (rhsOp == SpvOpDecorateId && lhsOp != SpvOpDecorateId) return false; + PRIORITY_CASE(SpvOpGroupDecorate) + PRIORITY_CASE(SpvOpGroupMemberDecorate) + PRIORITY_CASE(SpvOpDecorate) + PRIORITY_CASE(SpvOpMemberDecorate) + PRIORITY_CASE(SpvOpDecorateId) + PRIORITY_CASE(SpvOpDecorateStringGOOGLE) // OpDecorationGroup is lowest priority to ensure use/def chains remain // usable for instructions that target this group. - if (lhsOp == SpvOpDecorationGroup && rhsOp != SpvOpDecorationGroup) - return true; - if (rhsOp == SpvOpDecorationGroup && lhsOp != SpvOpDecorationGroup) - return false; + PRIORITY_CASE(SpvOpDecorationGroup) +#undef PRIORITY_CASE } // Fall back to maintain total ordering (compare unique ids). @@ -89,23 +82,30 @@ struct DecorationLess { bool AggressiveDCEPass::IsVarOfStorage(uint32_t varId, uint32_t storageClass) { if (varId == 0) return false; - const ir::Instruction* varInst = get_def_use_mgr()->GetDef(varId); + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); const SpvOp op = varInst->opcode(); if (op != SpvOpVariable) return false; const uint32_t varTypeId = varInst->type_id(); - const ir::Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); if (varTypeInst->opcode() != SpvOpTypePointer) return false; return varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) == storageClass; } bool AggressiveDCEPass::IsLocalVar(uint32_t varId) { - return IsVarOfStorage(varId, SpvStorageClassFunction) || - (IsVarOfStorage(varId, SpvStorageClassPrivate) && private_like_local_); + if (IsVarOfStorage(varId, SpvStorageClassFunction)) { + return true; + } + if (!private_like_local_) { + return false; + } + + return IsVarOfStorage(varId, SpvStorageClassPrivate) || + IsVarOfStorage(varId, SpvStorageClassWorkgroup); } void AggressiveDCEPass::AddStores(uint32_t ptrId) { - get_def_use_mgr()->ForEachUser(ptrId, [this](ir::Instruction* user) { + get_def_use_mgr()->ForEachUser(ptrId, [this, ptrId](Instruction* user) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: @@ -114,6 +114,12 @@ void AggressiveDCEPass::AddStores(uint32_t ptrId) { break; case SpvOpLoad: break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: + if (user->GetSingleWordInOperand(kCopyMemoryTargetAddrInIdx) == ptrId) { + AddToWorklist(user); + } + break; // If default, assume it stores e.g. frexp, modf, function call case SpvOpStore: default: @@ -134,7 +140,7 @@ bool AggressiveDCEPass::AllExtensionsSupported() const { return true; } -bool AggressiveDCEPass::IsDead(ir::Instruction* inst) { +bool AggressiveDCEPass::IsDead(Instruction* inst) { if (IsLive(inst)) return false; if (inst->IsBranch() && !IsStructuredHeader(context()->get_instr_block(inst), nullptr, nullptr, nullptr)) @@ -142,16 +148,16 @@ bool AggressiveDCEPass::IsDead(ir::Instruction* inst) { return true; } -bool AggressiveDCEPass::IsTargetDead(ir::Instruction* inst) { +bool AggressiveDCEPass::IsTargetDead(Instruction* inst) { const uint32_t tId = inst->GetSingleWordInOperand(0); - ir::Instruction* tInst = get_def_use_mgr()->GetDef(tId); - if (ir::IsAnnotationInst(tInst->opcode())) { + Instruction* tInst = get_def_use_mgr()->GetDef(tId); + if (IsAnnotationInst(tInst->opcode())) { // This must be a decoration group. We go through annotations in a specific // order. So if this is not used by any group or group member decorates, it // is dead. assert(tInst->opcode() == SpvOpDecorationGroup); bool dead = true; - get_def_use_mgr()->ForEachUser(tInst, [&dead](ir::Instruction* user) { + get_def_use_mgr()->ForEachUser(tInst, [&dead](Instruction* user) { if (user->opcode() == SpvOpGroupDecorate || user->opcode() == SpvOpGroupMemberDecorate) dead = false; @@ -172,14 +178,14 @@ void AggressiveDCEPass::ProcessLoad(uint32_t varId) { live_local_vars_.insert(varId); } -bool AggressiveDCEPass::IsStructuredHeader(ir::BasicBlock* bp, - ir::Instruction** mergeInst, - ir::Instruction** branchInst, +bool AggressiveDCEPass::IsStructuredHeader(BasicBlock* bp, + Instruction** mergeInst, + Instruction** branchInst, uint32_t* mergeBlockId) { if (!bp) return false; - ir::Instruction* mi = bp->GetMergeInst(); + Instruction* mi = bp->GetMergeInst(); if (mi == nullptr) return false; - ir::Instruction* bri = &*bp->tail(); + Instruction* bri = &*bp->tail(); if (branchInst != nullptr) *branchInst = bri; if (mergeInst != nullptr) *mergeInst = mi; if (mergeBlockId != nullptr) *mergeBlockId = mi->GetSingleWordInOperand(0); @@ -187,11 +193,11 @@ bool AggressiveDCEPass::IsStructuredHeader(ir::BasicBlock* bp, } void AggressiveDCEPass::ComputeBlock2HeaderMaps( - std::list& structuredOrder) { + std::list& structuredOrder) { block2headerBranch_.clear(); branch2merge_.clear(); structured_order_index_.clear(); - std::stack currentHeaderBranch; + std::stack currentHeaderBranch; currentHeaderBranch.push(nullptr); uint32_t currentMergeBlockId = 0; uint32_t index = 0; @@ -202,12 +208,12 @@ void AggressiveDCEPass::ComputeBlock2HeaderMaps( // we are leaving the current construct so we must update state if ((*bi)->id() == currentMergeBlockId) { currentHeaderBranch.pop(); - ir::Instruction* chb = currentHeaderBranch.top(); + Instruction* chb = currentHeaderBranch.top(); if (chb != nullptr) currentMergeBlockId = branch2merge_[chb]->GetSingleWordInOperand(0); } - ir::Instruction* mergeInst; - ir::Instruction* branchInst; + Instruction* mergeInst; + Instruction* branchInst; uint32_t mergeBlockId; bool is_header = IsStructuredHeader(*bi, &mergeInst, &branchInst, &mergeBlockId); @@ -229,44 +235,52 @@ void AggressiveDCEPass::ComputeBlock2HeaderMaps( } } -void AggressiveDCEPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { - std::unique_ptr newBranch(new ir::Instruction( - context(), SpvOpBranch, 0, 0, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); - get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch); +void AggressiveDCEPass::AddBranch(uint32_t labelId, BasicBlock* bp) { + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); + context()->AnalyzeDefUse(&*newBranch); + context()->set_instr_block(&*newBranch, bp); bp->AddInstruction(std::move(newBranch)); } void AggressiveDCEPass::AddBreaksAndContinuesToWorklist( - ir::Instruction* loopMerge) { - ir::BasicBlock* header = context()->get_instr_block(loopMerge); + Instruction* mergeInst) { + assert(mergeInst->opcode() == SpvOpSelectionMerge || + mergeInst->opcode() == SpvOpLoopMerge); + + BasicBlock* header = context()->get_instr_block(mergeInst); uint32_t headerIndex = structured_order_index_[header]; - const uint32_t mergeId = - loopMerge->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx); - ir::BasicBlock* merge = context()->get_instr_block(mergeId); + const uint32_t mergeId = mergeInst->GetSingleWordInOperand(0); + BasicBlock* merge = context()->get_instr_block(mergeId); uint32_t mergeIndex = structured_order_index_[merge]; get_def_use_mgr()->ForEachUser( - mergeId, [headerIndex, mergeIndex, this](ir::Instruction* user) { + mergeId, [headerIndex, mergeIndex, this](Instruction* user) { if (!user->IsBranch()) return; - ir::BasicBlock* block = context()->get_instr_block(user); + BasicBlock* block = context()->get_instr_block(user); uint32_t index = structured_order_index_[block]; if (headerIndex < index && index < mergeIndex) { // This is a break from the loop. AddToWorklist(user); // Add branch's merge if there is one. - ir::Instruction* userMerge = branch2merge_[user]; + Instruction* userMerge = branch2merge_[user]; if (userMerge != nullptr) AddToWorklist(userMerge); } }); + + if (mergeInst->opcode() != SpvOpLoopMerge) { + return; + } + + // For loops we need to find the continues as well. const uint32_t contId = - loopMerge->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx); - get_def_use_mgr()->ForEachUser(contId, [&contId, - this](ir::Instruction* user) { + mergeInst->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx); + get_def_use_mgr()->ForEachUser(contId, [&contId, this](Instruction* user) { SpvOp op = user->opcode(); if (op == SpvOpBranchConditional || op == SpvOpSwitch) { // A conditional branch or switch can only be a continue if it does not // have a merge instruction or its merge block is not the continue block. - ir::Instruction* hdrMerge = branch2merge_[user]; + Instruction* hdrMerge = branch2merge_[user]; if (hdrMerge != nullptr && hdrMerge->opcode() == SpvOpSelectionMerge) { uint32_t hdrMergeId = hdrMerge->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); @@ -277,10 +291,10 @@ void AggressiveDCEPass::AddBreaksAndContinuesToWorklist( } else if (op == SpvOpBranch) { // An unconditional branch can only be a continue if it is not // branching to its own merge block. - ir::BasicBlock* blk = context()->get_instr_block(user); - ir::Instruction* hdrBranch = block2headerBranch_[blk]; + BasicBlock* blk = context()->get_instr_block(user); + Instruction* hdrBranch = block2headerBranch_[blk]; if (hdrBranch == nullptr) return; - ir::Instruction* hdrMerge = branch2merge_[hdrBranch]; + Instruction* hdrMerge = branch2merge_[hdrBranch]; if (hdrMerge->opcode() == SpvOpLoopMerge) return; uint32_t hdrMergeId = hdrMerge->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); @@ -292,17 +306,17 @@ void AggressiveDCEPass::AddBreaksAndContinuesToWorklist( }); } -bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { +bool AggressiveDCEPass::AggressiveDCE(Function* func) { // Mark function parameters as live. AddToWorklist(&func->DefInst()); func->ForEachParam( - [this](const ir::Instruction* param) { - AddToWorklist(const_cast(param)); + [this](const Instruction* param) { + AddToWorklist(const_cast(param)); }, false); // Compute map from block to controlling conditional branch - std::list structuredOrder; + std::list structuredOrder; cfg()->ComputeStructuredOrder(func, &*func->begin(), &structuredOrder); ComputeBlock2HeaderMaps(structuredOrder); bool modified = false; @@ -335,8 +349,21 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { (void)GetPtr(&*ii, &varId); // Mark stores as live if their variable is not function scope // and is not private scope. Remember private stores for possible - // later inclusion - if (IsVarOfStorage(varId, SpvStorageClassPrivate)) + // later inclusion. We cannot call IsLocalVar at this point because + // private_like_local_ has not been set yet. + if (IsVarOfStorage(varId, SpvStorageClassPrivate) || + IsVarOfStorage(varId, SpvStorageClassWorkgroup)) + private_stores_.push_back(&*ii); + else if (!IsVarOfStorage(varId, SpvStorageClassFunction)) + AddToWorklist(&*ii); + } break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: { + uint32_t varId; + (void)GetPtr(ii->GetSingleWordInOperand(kCopyMemoryTargetAddrInIdx), + &varId); + if (IsVarOfStorage(varId, SpvStorageClassPrivate) || + IsVarOfStorage(varId, SpvStorageClassWorkgroup)) private_stores_.push_back(&*ii); else if (!IsVarOfStorage(varId, SpvStorageClassFunction)) AddToWorklist(&*ii); @@ -354,12 +381,14 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { case SpvOpSwitch: case SpvOpBranch: case SpvOpBranchConditional: { - if (assume_branches_live.top()) AddToWorklist(&*ii); + if (assume_branches_live.top()) { + AddToWorklist(&*ii); + } } break; default: { // Function calls, atomics, function params, function returns, etc. // TODO(greg-lunarg): function calls live only if write to non-local - if (!context()->IsCombinatorInstruction(&*ii)) { + if (!ii->IsOpcodeSafeToDelete()) { AddToWorklist(&*ii); } // Remember function calls @@ -384,10 +413,10 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { for (auto& ps : private_stores_) AddToWorklist(ps); // Perform closure on live instruction set. while (!worklist_.empty()) { - ir::Instruction* liveInst = worklist_.front(); + Instruction* liveInst = worklist_.front(); // Add all operand instructions if not already live liveInst->ForEachInId([&liveInst, this](const uint32_t* iid) { - ir::Instruction* inInst = get_def_use_mgr()->GetDef(*iid); + Instruction* inInst = get_def_use_mgr()->GetDef(*iid); // Do not add label if an operand of a branch. This is not needed // as part of live code discovery and can create false live code, // for example, the branch to a header of a loop. @@ -401,15 +430,13 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { // conditional branch and its merge. Any containing control construct // is marked live when the merge and branch are processed out of the // worklist. - ir::BasicBlock* blk = context()->get_instr_block(liveInst); - ir::Instruction* branchInst = block2headerBranch_[blk]; + BasicBlock* blk = context()->get_instr_block(liveInst); + Instruction* branchInst = block2headerBranch_[blk]; if (branchInst != nullptr) { AddToWorklist(branchInst); - ir::Instruction* mergeInst = branch2merge_[branchInst]; + Instruction* mergeInst = branch2merge_[branchInst]; AddToWorklist(mergeInst); - // If in a loop, mark all its break and continue instructions live - if (mergeInst->opcode() == SpvOpLoopMerge) - AddBreaksAndContinuesToWorklist(mergeInst); + AddBreaksAndContinuesToWorklist(mergeInst); } // If local load, add all variable's stores if variable not already live if (liveInst->opcode() == SpvOpLoad) { @@ -418,9 +445,16 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { if (varId != 0) { ProcessLoad(varId); } - } - // If function call, treat as if it loads from all pointer arguments - else if (liveInst->opcode() == SpvOpFunctionCall) { + } else if (liveInst->opcode() == SpvOpCopyMemory || + liveInst->opcode() == SpvOpCopyMemorySized) { + uint32_t varId; + (void)GetPtr(liveInst->GetSingleWordInOperand(kCopyMemorySourceAddrInIdx), + &varId); + if (varId != 0) { + ProcessLoad(varId); + } + // If function call, treat as if it loads from all pointer arguments + } else if (liveInst->opcode() == SpvOpFunctionCall) { liveInst->ForEachInId([this](const uint32_t* iid) { // Skip non-ptr args if (!IsPtr(*iid)) return; @@ -428,14 +462,12 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { (void)GetPtr(*iid, &varId); ProcessLoad(varId); }); - } - // If function parameter, treat as if it's result id is loaded from - else if (liveInst->opcode() == SpvOpFunctionParameter) { + // If function parameter, treat as if it's result id is loaded from + } else if (liveInst->opcode() == SpvOpFunctionParameter) { ProcessLoad(liveInst->result_id()); - } - // We treat an OpImageTexelPointer as a load of the pointer, and - // that value is manipulated to get the result. - else if (liveInst->opcode() == SpvOpImageTexelPointer) { + // We treat an OpImageTexelPointer as a load of the pointer, and + // that value is manipulated to get the result. + } else if (liveInst->opcode() == SpvOpImageTexelPointer) { uint32_t varId; (void)GetPtr(liveInst, &varId); if (varId != 0) { @@ -448,7 +480,7 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { // Kill dead instructions and remember dead blocks for (auto bi = structuredOrder.begin(); bi != structuredOrder.end();) { uint32_t mergeBlockId = 0; - (*bi)->ForEachInst([this, &modified, &mergeBlockId](ir::Instruction* inst) { + (*bi)->ForEachInst([this, &modified, &mergeBlockId](Instruction* inst) { if (!IsDead(inst)) return; if (inst->opcode() == SpvOpLabel) return; // If dead instruction is selection merge, remember merge block @@ -474,18 +506,6 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { return modified; } -void AggressiveDCEPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); - - // Clear collections - worklist_ = std::queue{}; - live_insts_.clear(); - live_local_vars_.clear(); - - // Initialize extensions whitelist - InitExtensions(); -} - void AggressiveDCEPass::InitializeModuleScopeLiveInstructions() { // Keep all execution modes. for (auto& exec : get_module()->execution_modes()) { @@ -526,7 +546,7 @@ Pass::Status AggressiveDCEPass::ProcessImpl() { InitializeModuleScopeLiveInstructions(); // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { return AggressiveDCE(fp); }; + ProcessFunction pfn = [this](Function* fp) { return AggressiveDCE(fp); }; modified |= ProcessEntryPointCallTree(pfn, get_module()); // Process module-level instructions. Now that all live instructions have @@ -539,7 +559,7 @@ Pass::Status AggressiveDCEPass::ProcessImpl() { } // Cleanup all CFG including all unreachable blocks. - ProcessFunction cleanup = [this](ir::Function* f) { return CFGCleanup(f); }; + ProcessFunction cleanup = [this](Function* f) { return CFGCleanup(f); }; modified |= ProcessEntryPointCallTree(cleanup, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; @@ -549,8 +569,8 @@ bool AggressiveDCEPass::EliminateDeadFunctions() { // Identify live functions first. Those that are not live // are dead. ADCE is disabled for non-shaders so we do not check for exported // functions here. - std::unordered_set live_function_set; - ProcessFunction mark_live = [&live_function_set](ir::Function* fp) { + std::unordered_set live_function_set; + ProcessFunction mark_live = [&live_function_set](Function* fp) { live_function_set.insert(fp); return false; }; @@ -571,10 +591,10 @@ bool AggressiveDCEPass::EliminateDeadFunctions() { return modified; } -void AggressiveDCEPass::EliminateFunction(ir::Function* func) { +void AggressiveDCEPass::EliminateFunction(Function* func) { // Remove all of the instruction in the function body - func->ForEachInst( - [this](ir::Instruction* inst) { context()->KillInst(inst); }, true); + func->ForEachInst([this](Instruction* inst) { context()->KillInst(inst); }, + true); } bool AggressiveDCEPass::ProcessGlobalValues() { @@ -582,7 +602,7 @@ bool AggressiveDCEPass::ProcessGlobalValues() { // This must be done before killing the instructions, otherwise there are // dead objects in the def/use database. bool modified = false; - ir::Instruction* instruction = &*get_module()->debug2_begin(); + Instruction* instruction = &*get_module()->debug2_begin(); while (instruction) { if (instruction->opcode() != SpvOpName) { instruction = instruction->NextNode(); @@ -600,7 +620,7 @@ bool AggressiveDCEPass::ProcessGlobalValues() { // This code removes all unnecessary decorations safely (see #1174). It also // does so in a more efficient manner than deleting them only as the targets // are deleted. - std::vector annotations; + std::vector annotations; for (auto& inst : get_module()->annotations()) annotations.push_back(&inst); std::sort(annotations.begin(), annotations.end(), DecorationLess()); for (auto annotation : annotations) { @@ -608,24 +628,32 @@ bool AggressiveDCEPass::ProcessGlobalValues() { case SpvOpDecorate: case SpvOpMemberDecorate: case SpvOpDecorateId: - if (IsTargetDead(annotation)) context()->KillInst(annotation); + case SpvOpDecorateStringGOOGLE: + if (IsTargetDead(annotation)) { + context()->KillInst(annotation); + modified = true; + } break; case SpvOpGroupDecorate: { // Go through the targets of this group decorate. Remove each dead // target. If all targets are dead, remove this decoration. bool dead = true; for (uint32_t i = 1; i < annotation->NumOperands();) { - ir::Instruction* opInst = + Instruction* opInst = get_def_use_mgr()->GetDef(annotation->GetSingleWordOperand(i)); if (IsDead(opInst)) { // Don't increment |i|. annotation->RemoveOperand(i); + modified = true; } else { i++; dead = false; } } - if (dead) context()->KillInst(annotation); + if (dead) { + context()->KillInst(annotation); + modified = true; + } break; } case SpvOpGroupMemberDecorate: { @@ -634,25 +662,31 @@ bool AggressiveDCEPass::ProcessGlobalValues() { // decoration. bool dead = true; for (uint32_t i = 1; i < annotation->NumOperands();) { - ir::Instruction* opInst = + Instruction* opInst = get_def_use_mgr()->GetDef(annotation->GetSingleWordOperand(i)); if (IsDead(opInst)) { // Don't increment |i|. annotation->RemoveOperand(i + 1); annotation->RemoveOperand(i); + modified = true; } else { i += 2; dead = false; } } - if (dead) context()->KillInst(annotation); + if (dead) { + context()->KillInst(annotation); + modified = true; + } break; } case SpvOpDecorationGroup: // By the time we hit decoration groups we've checked everything that // can target them. So if they have no uses they must be dead. - if (get_def_use_mgr()->NumUsers(annotation) == 0) + if (get_def_use_mgr()->NumUsers(annotation) == 0) { context()->KillInst(annotation); + modified = true; + } break; default: assert(false); @@ -671,10 +705,11 @@ bool AggressiveDCEPass::ProcessGlobalValues() { return modified; } -AggressiveDCEPass::AggressiveDCEPass() {} +AggressiveDCEPass::AggressiveDCEPass() = default; -Pass::Status AggressiveDCEPass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status AggressiveDCEPass::Process() { + // Initialize extensions whitelist + InitExtensions(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.h b/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.h index 9c1749b19..3c03cc66b 100644 --- a/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.h @@ -14,38 +14,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_AGGRESSIVE_DCE_PASS_H_ -#define LIBSPIRV_OPT_AGGRESSIVE_DCE_PASS_H_ +#ifndef SOURCE_OPT_AGGRESSIVE_DEAD_CODE_ELIM_PASS_H_ +#define SOURCE_OPT_AGGRESSIVE_DEAD_CODE_ELIM_PASS_H_ #include +#include #include #include +#include #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" +#include "source/util/bit_vector.h" namespace spvtools { namespace opt { // See optimizer.hpp for documentation. class AggressiveDCEPass : public MemPass { - using cbb_ptr = const ir::BasicBlock*; + using cbb_ptr = const BasicBlock*; public: using GetBlocksFunction = - std::function*(const ir::BasicBlock*)>; + std::function*(const BasicBlock*)>; AggressiveDCEPass(); const char* name() const override { return "eliminate-dead-code-aggressive"; } - Status Process(ir::IRContext* c) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; } private: @@ -59,20 +63,22 @@ class AggressiveDCEPass : public MemPass { bool IsLocalVar(uint32_t varId); // Return true if |inst| is marked live. - bool IsLive(const ir::Instruction* inst) const { - return live_insts_.find(inst) != live_insts_.end(); + bool IsLive(const Instruction* inst) const { + return live_insts_.Get(inst->unique_id()); } // Returns true if |inst| is dead. - bool IsDead(ir::Instruction* inst); + bool IsDead(Instruction* inst); // Adds entry points, execution modes and workgroup size decorations to the // worklist for processing with the first function. void InitializeModuleScopeLiveInstructions(); // Add |inst| to worklist_ and live_insts_. - void AddToWorklist(ir::Instruction* inst) { - if (live_insts_.insert(inst).second) worklist_.push(inst); + void AddToWorklist(Instruction* inst) { + if (!live_insts_.Set(inst->unique_id())) { + worklist_.push(inst); + } } // Add all store instruction which use |ptrId|, directly or indirectly, @@ -88,7 +94,7 @@ class AggressiveDCEPass : public MemPass { // Returns true if the target of |inst| is dead. An instruction is dead if // its result id is used in decoration or debug instructions only. |inst| is // assumed to be OpName, OpMemberName or an annotation instruction. - bool IsTargetDead(ir::Instruction* inst); + bool IsTargetDead(Instruction* inst); // If |varId| is local, mark all stores of varId as live. void ProcessLoad(uint32_t varId); @@ -98,19 +104,19 @@ class AggressiveDCEPass : public MemPass { // merge block if they are not nullptr. Any of |mergeInst|, |branchInst| or // |mergeBlockId| may be a null pointer. Returns false if |bp| is a null // pointer. - bool IsStructuredHeader(ir::BasicBlock* bp, ir::Instruction** mergeInst, - ir::Instruction** branchInst, uint32_t* mergeBlockId); + bool IsStructuredHeader(BasicBlock* bp, Instruction** mergeInst, + Instruction** branchInst, uint32_t* mergeBlockId); // Initialize block2headerBranch_ and branch2merge_ using |structuredOrder| // to order blocks. - void ComputeBlock2HeaderMaps(std::list& structuredOrder); + void ComputeBlock2HeaderMaps(std::list& structuredOrder); // Add branch to |labelId| to end of block |bp|. - void AddBranch(uint32_t labelId, ir::BasicBlock* bp); + void AddBranch(uint32_t labelId, BasicBlock* bp); - // Add all break and continue branches in the loop associated with + // Add all break and continue branches in the construct associated with // |mergeInst| to worklist if not already live - void AddBreaksAndContinuesToWorklist(ir::Instruction* mergeInst); + void AddBreaksAndContinuesToWorklist(Instruction* mergeInst); // Eliminates dead debug2 and annotation instructions. Marks dead globals for // removal (e.g. types, constants and variables). @@ -120,7 +126,7 @@ class AggressiveDCEPass : public MemPass { bool EliminateDeadFunctions(); // Removes |func| from the module and deletes all its instructions. - void EliminateFunction(ir::Function* func); + void EliminateFunction(Function* func); // For function |func|, mark all Stores to non-function-scope variables // and block terminating instructions as live. Recursively mark the values @@ -131,9 +137,8 @@ class AggressiveDCEPass : public MemPass { // existing control structures will remain. This can leave not-insignificant // sequences of ultimately useless code. // TODO(): Remove useless control constructs. - bool AggressiveDCE(ir::Function* func); + bool AggressiveDCE(Function* func); - void Initialize(ir::IRContext* c); Pass::Status ProcessImpl(); // True if current function has a call instruction contained in it @@ -150,32 +155,32 @@ class AggressiveDCEPass : public MemPass { // If we don't know, then add it to this list. Instructions are // removed from this list as the algorithm traces side effects, // building up the live instructions set |live_insts_|. - std::queue worklist_; + std::queue worklist_; // Map from block to the branch instruction in the header of the most // immediate controlling structured if or loop. A loop header block points // to its own branch instruction. An if-selection block points to the branch // of an enclosing construct's header, if one exists. - std::unordered_map block2headerBranch_; + std::unordered_map block2headerBranch_; // Maps basic block to their index in the structured order traversal. - std::unordered_map structured_order_index_; + std::unordered_map structured_order_index_; // Map from branch to its associated merge instruction, if any - std::unordered_map branch2merge_; + std::unordered_map branch2merge_; // Store instructions to variables of private storage - std::vector private_stores_; + std::vector private_stores_; // Live Instructions - std::unordered_set live_insts_; + utils::BitVector live_insts_; // Live Local Variables std::unordered_set live_local_vars_; // List of instructions to delete. Deletion is delayed until debug and // annotation instructions are processed. - std::vector to_kill_; + std::vector to_kill_; // Extensions supported by this pass. std::unordered_set extensions_whitelist_; @@ -184,4 +189,4 @@ class AggressiveDCEPass : public MemPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_AGGRESSIVE_DCE_PASS_H_ +#endif // SOURCE_OPT_AGGRESSIVE_DEAD_CODE_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/basic_block.cpp b/3rdparty/spirv-tools/source/opt/basic_block.cpp index 98b069585..b18b114a5 100644 --- a/3rdparty/spirv-tools/source/opt/basic_block.cpp +++ b/3rdparty/spirv-tools/source/opt/basic_block.cpp @@ -12,19 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "basic_block.h" -#include "function.h" -#include "ir_context.h" -#include "module.h" -#include "reflect.h" - -#include "make_unique.h" +#include "source/opt/basic_block.h" #include -namespace spvtools { -namespace ir { +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" +namespace spvtools { +namespace opt { namespace { const uint32_t kLoopMergeContinueBlockIdInIdx = 1; @@ -91,7 +90,7 @@ Instruction* BasicBlock::GetLoopMergeInst() { } void BasicBlock::KillAllInsts(bool killLabel) { - ForEachInst([killLabel](ir::Instruction* ip) { + ForEachInst([killLabel](Instruction* ip) { if (killLabel || ip->opcode() != SpvOpLabel) { ip->context()->KillInst(ip); } @@ -140,7 +139,7 @@ void BasicBlock::ForEachSuccessorLabel( } } -bool BasicBlock::IsSuccessor(const ir::BasicBlock* block) const { +bool BasicBlock::IsSuccessor(const BasicBlock* block) const { uint32_t succId = block->id(); bool isSuccessor = false; ForEachSuccessorLabel([&isSuccessor, succId](const uint32_t label) { @@ -196,7 +195,7 @@ std::ostream& operator<<(std::ostream& str, const BasicBlock& block) { std::string BasicBlock::PrettyPrint(uint32_t options) const { std::ostringstream str; - ForEachInst([&str, options](const ir::Instruction* inst) { + ForEachInst([&str, options](const Instruction* inst) { str << inst->PrettyPrint(options); if (!IsTerminatorInst(inst->opcode())) { str << std::endl; @@ -210,13 +209,36 @@ BasicBlock* BasicBlock::SplitBasicBlock(IRContext* context, uint32_t label_id, assert(!insts_.empty()); BasicBlock* new_block = new BasicBlock(MakeUnique( - context, SpvOpLabel, 0, label_id, std::initializer_list{})); + context, SpvOpLabel, 0, label_id, std::initializer_list{})); new_block->insts_.Splice(new_block->end(), &insts_, iter, end()); new_block->SetParent(GetParent()); - if (context->AreAnalysesValid(ir::IRContext::kAnalysisInstrToBlockMapping)) { - new_block->ForEachInst([new_block, context](ir::Instruction* inst) { + context->AnalyzeDefUse(new_block->GetLabelInst()); + + // Update the phi nodes in the successor blocks to reference the new block id. + const_cast(new_block)->ForEachSuccessorLabel( + [new_block, this, context](const uint32_t label) { + BasicBlock* target_bb = context->get_instr_block(label); + target_bb->ForEachPhiInst( + [this, new_block, context](Instruction* phi_inst) { + bool changed = false; + for (uint32_t i = 1; i < phi_inst->NumInOperands(); i += 2) { + if (phi_inst->GetSingleWordInOperand(i) == this->id()) { + changed = true; + phi_inst->SetInOperand(i, {new_block->id()}); + } + } + + if (changed) { + context->UpdateDefUse(phi_inst); + } + }); + }); + + if (context->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) { + context->set_instr_block(new_block->GetLabelInst(), new_block); + new_block->ForEachInst([new_block, context](Instruction* inst) { context->set_instr_block(inst, new_block); }); } @@ -224,5 +246,5 @@ BasicBlock* BasicBlock::SplitBasicBlock(IRContext* context, uint32_t label_id, return new_block; } -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/basic_block.h b/3rdparty/spirv-tools/source/opt/basic_block.h index 0df1f76ae..9e1706e14 100644 --- a/3rdparty/spirv-tools/source/opt/basic_block.h +++ b/3rdparty/spirv-tools/source/opt/basic_block.h @@ -15,21 +15,23 @@ // This file defines the language constructs for representing a SPIR-V // module in memory. -#ifndef LIBSPIRV_OPT_BASIC_BLOCK_H_ -#define LIBSPIRV_OPT_BASIC_BLOCK_H_ +#ifndef SOURCE_OPT_BASIC_BLOCK_H_ +#define SOURCE_OPT_BASIC_BLOCK_H_ #include +#include #include #include +#include #include #include -#include "instruction.h" -#include "instruction_list.h" -#include "iterator.h" +#include "source/opt/instruction.h" +#include "source/opt/instruction_list.h" +#include "source/opt/iterator.h" namespace spvtools { -namespace ir { +namespace opt { class Function; class IRContext; @@ -39,6 +41,9 @@ class BasicBlock { public: using iterator = InstructionList::iterator; using const_iterator = InstructionList::const_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = + std::reverse_iterator; // Creates a basic block with the given starting |label|. inline explicit BasicBlock(std::unique_ptr label); @@ -87,6 +92,21 @@ class BasicBlock { const_iterator cbegin() const { return insts_.cbegin(); } const_iterator cend() const { return insts_.cend(); } + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(cend()); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(cbegin()); + } + const_reverse_iterator crbegin() const { + return const_reverse_iterator(cend()); + } + const_reverse_iterator crend() const { + return const_reverse_iterator(cbegin()); + } + // Returns an iterator pointing to the last instruction. This may only // be used if this block has an instruction other than the OpLabel // that defines it. @@ -140,14 +160,14 @@ class BasicBlock { void ForEachSuccessorLabel(const std::function& f); // Returns true if |block| is a direct successor of |this|. - bool IsSuccessor(const ir::BasicBlock* block) const; + bool IsSuccessor(const BasicBlock* block) const; // Runs the given function |f| on the merge and continue label, if any void ForMergeAndContinueLabel(const std::function& f); // Returns true if this basic block has any Phi instructions. bool HasPhiInstructions() { - return !WhileEachPhiInst([](ir::Instruction*) { return false; }); + return !WhileEachPhiInst([](Instruction*) { return false; }); } // Return true if this block is a loop header block. @@ -294,7 +314,7 @@ inline void BasicBlock::ForEachPhiInst( run_on_debug_line_insts); } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_BASIC_BLOCK_H_ +#endif // SOURCE_OPT_BASIC_BLOCK_H_ diff --git a/3rdparty/spirv-tools/source/opt/block_merge_pass.cpp b/3rdparty/spirv-tools/source/opt/block_merge_pass.cpp index fb370d690..aa4c1bd92 100644 --- a/3rdparty/spirv-tools/source/opt/block_merge_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/block_merge_pass.cpp @@ -14,17 +14,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "block_merge_pass.h" +#include "source/opt/block_merge_pass.h" -#include "ir_context.h" -#include "iterator.h" +#include + +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" namespace spvtools { namespace opt { -void BlockMergePass::KillInstAndName(ir::Instruction* inst) { - std::vector to_kill; - get_def_use_mgr()->ForEachUser(inst, [&to_kill](ir::Instruction* user) { +void BlockMergePass::KillInstAndName(Instruction* inst) { + std::vector to_kill; + get_def_use_mgr()->ForEachUser(inst, [&to_kill](Instruction* user) { if (user->opcode() == SpvOpName) { to_kill.push_back(user); } @@ -35,13 +37,13 @@ void BlockMergePass::KillInstAndName(ir::Instruction* inst) { context()->KillInst(inst); } -bool BlockMergePass::MergeBlocks(ir::Function* func) { +bool BlockMergePass::MergeBlocks(Function* func) { bool modified = false; for (auto bi = func->begin(); bi != func->end();) { // Find block with single successor which has no other predecessors. auto ii = bi->end(); --ii; - ir::Instruction* br = &*ii; + Instruction* br = &*ii; if (br->opcode() != SpvOpBranch) { ++bi; continue; @@ -69,14 +71,14 @@ bool BlockMergePass::MergeBlocks(ir::Function* func) { continue; } - ir::Instruction* merge_inst = bi->GetMergeInst(); + Instruction* merge_inst = bi->GetMergeInst(); if (pred_is_header && lab_id != merge_inst->GetSingleWordInOperand(0u)) { // If this is a header block and the successor is not its merge, we must // be careful about which blocks we are willing to merge together. // OpLoopMerge must be followed by a conditional or unconditional branch. // The merge must be a loop merge because a selection merge cannot be // followed by an unconditional branch. - ir::BasicBlock* succ_block = context()->get_instr_block(lab_id); + BasicBlock* succ_block = context()->get_instr_block(lab_id); SpvOp succ_term_op = succ_block->terminator()->opcode(); assert(merge_inst->opcode() == SpvOpLoopMerge); if (succ_term_op != SpvOpBranch && @@ -94,7 +96,15 @@ bool BlockMergePass::MergeBlocks(ir::Function* func) { // If bi is sbi's only predecessor, it dominates sbi and thus // sbi must follow bi in func's ordering. assert(sbi != func->end()); + + // Update the inst-to-block mapping for the instructions in sbi. + for (auto& inst : *sbi) { + context()->set_instr_block(&inst, &*bi); + } + + // Now actually move the instructions. bi->AddInstructions(&*sbi); + if (merge_inst) { if (pred_is_header && lab_id == merge_inst->GetSingleWordInOperand(0u)) { // Merging the header and merge blocks, so remove the structured control @@ -114,7 +124,7 @@ bool BlockMergePass::MergeBlocks(ir::Function* func) { return modified; } -bool BlockMergePass::IsHeader(ir::BasicBlock* block) { +bool BlockMergePass::IsHeader(BasicBlock* block) { return block->GetMergeInst() != nullptr; } @@ -123,7 +133,7 @@ bool BlockMergePass::IsHeader(uint32_t id) { } bool BlockMergePass::IsMerge(uint32_t id) { - return !get_def_use_mgr()->WhileEachUse(id, [](ir::Instruction* user, + return !get_def_use_mgr()->WhileEachUse(id, [](Instruction* user, uint32_t index) { SpvOp op = user->opcode(); if ((op == SpvOpLoopMerge || op == SpvOpSelectionMerge) && index == 0u) { @@ -133,25 +143,16 @@ bool BlockMergePass::IsMerge(uint32_t id) { }); } -bool BlockMergePass::IsMerge(ir::BasicBlock* block) { - return IsMerge(block->id()); -} +bool BlockMergePass::IsMerge(BasicBlock* block) { return IsMerge(block->id()); } -void BlockMergePass::Initialize(ir::IRContext* c) { InitializeProcessing(c); } - -Pass::Status BlockMergePass::ProcessImpl() { +Pass::Status BlockMergePass::Process() { // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { return MergeBlocks(fp); }; + ProcessFunction pfn = [this](Function* fp) { return MergeBlocks(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -BlockMergePass::BlockMergePass() {} - -Pass::Status BlockMergePass::Process(ir::IRContext* c) { - Initialize(c); - return ProcessImpl(); -} +BlockMergePass::BlockMergePass() = default; } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/block_merge_pass.h b/3rdparty/spirv-tools/source/opt/block_merge_pass.h index a56f245e1..0ecde4884 100644 --- a/3rdparty/spirv-tools/source/opt/block_merge_pass.h +++ b/3rdparty/spirv-tools/source/opt/block_merge_pass.h @@ -14,8 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_BLOCK_MERGE_PASS_H_ -#define LIBSPIRV_OPT_BLOCK_MERGE_PASS_H_ +#ifndef SOURCE_OPT_BLOCK_MERGE_PASS_H_ +#define SOURCE_OPT_BLOCK_MERGE_PASS_H_ #include #include @@ -24,11 +24,11 @@ #include #include -#include "basic_block.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -38,30 +38,34 @@ class BlockMergePass : public Pass { public: BlockMergePass(); const char* name() const override { return "merge-blocks"; } - Status Process(ir::IRContext*) override; + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisNameMap; + } private: // Kill any OpName instruction referencing |inst|, then kill |inst|. - void KillInstAndName(ir::Instruction* inst); + void KillInstAndName(Instruction* inst); // Search |func| for blocks which have a single Branch to a block // with no other predecessors. Merge these blocks into a single block. - bool MergeBlocks(ir::Function* func); + bool MergeBlocks(Function* func); // Returns true if |block| (or |id|) contains a merge instruction. - bool IsHeader(ir::BasicBlock* block); + bool IsHeader(BasicBlock* block); bool IsHeader(uint32_t id); // Returns true if |block| (or |id|) is the merge target of a merge // instruction. - bool IsMerge(ir::BasicBlock* block); + bool IsMerge(BasicBlock* block); bool IsMerge(uint32_t id); - - void Initialize(ir::IRContext* c); - Pass::Status ProcessImpl(); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_BLOCK_MERGE_PASS_H_ +#endif // SOURCE_OPT_BLOCK_MERGE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/build_module.cpp b/3rdparty/spirv-tools/source/opt/build_module.cpp index c441fcccd..fc76a3c29 100644 --- a/3rdparty/spirv-tools/source/opt/build_module.cpp +++ b/3rdparty/spirv-tools/source/opt/build_module.cpp @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "build_module.h" +#include "source/opt/build_module.h" -#include "ir_context.h" -#include "ir_loader.h" -#include "make_unique.h" -#include "table.h" +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/ir_loader.h" +#include "source/table.h" +#include "source/util/make_unique.h" namespace spvtools { - namespace { // Sets the module header for IrLoader. Meets the interface requirement of @@ -28,7 +30,7 @@ namespace { spv_result_t SetSpvHeader(void* builder, spv_endianness_t, uint32_t magic, uint32_t version, uint32_t generator, uint32_t id_bound, uint32_t reserved) { - reinterpret_cast(builder)->SetModuleHeader( + reinterpret_cast(builder)->SetModuleHeader( magic, version, generator, id_bound, reserved); return SPV_SUCCESS; } @@ -36,7 +38,7 @@ spv_result_t SetSpvHeader(void* builder, spv_endianness_t, uint32_t magic, // Processes a parsed instruction for IrLoader. Meets the interface requirement // of spvBinaryParse(). spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { - if (reinterpret_cast(builder)->AddInstruction(inst)) { + if (reinterpret_cast(builder)->AddInstruction(inst)) { return SPV_SUCCESS; } return SPV_ERROR_INVALID_BINARY; @@ -44,15 +46,15 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { } // namespace -std::unique_ptr BuildModule(spv_target_env env, - MessageConsumer consumer, - const uint32_t* binary, - const size_t size) { +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const uint32_t* binary, + const size_t size) { auto context = spvContextCreate(env); - libspirv::SetContextMessageConsumer(context, consumer); + SetContextMessageConsumer(context, consumer); - auto irContext = MakeUnique(env, consumer); - ir::IrLoader loader(consumer, irContext->module()); + auto irContext = MakeUnique(env, consumer); + opt::IrLoader loader(consumer, irContext->module()); spv_result_t status = spvBinaryParse(context, &loader, binary, size, SetSpvHeader, SetSpvInst, nullptr); @@ -63,10 +65,10 @@ std::unique_ptr BuildModule(spv_target_env env, return status == SPV_SUCCESS ? std::move(irContext) : nullptr; } -std::unique_ptr BuildModule(spv_target_env env, - MessageConsumer consumer, - const std::string& text, - uint32_t assemble_options) { +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const std::string& text, + uint32_t assemble_options) { SpirvTools t(env); t.SetMessageConsumer(consumer); std::vector binary; diff --git a/3rdparty/spirv-tools/source/opt/build_module.h b/3rdparty/spirv-tools/source/opt/build_module.h index 7e9493634..c9d1cf2e4 100644 --- a/3rdparty/spirv-tools/source/opt/build_module.h +++ b/3rdparty/spirv-tools/source/opt/build_module.h @@ -12,34 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_OPT_BUILD_MODULE_H_ -#define SPIRV_TOOLS_OPT_BUILD_MODULE_H_ +#ifndef SOURCE_OPT_BUILD_MODULE_H_ +#define SOURCE_OPT_BUILD_MODULE_H_ #include #include -#include "ir_context.h" -#include "module.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { -// Builds an ir::Module returns the owning ir::IRContext from the given SPIR-V +// Builds an Module returns the owning IRContext from the given SPIR-V // |binary|. |size| specifies number of words in |binary|. The |binary| will be // decoded according to the given target |env|. Returns nullptr if errors occur // and sends the errors to |consumer|. -std::unique_ptr BuildModule(spv_target_env env, - MessageConsumer consumer, - const uint32_t* binary, size_t size); +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const uint32_t* binary, + size_t size); -// Builds an ir::Module and returns the owning ir::IRContext from the given +// Builds an Module and returns the owning IRContext from the given // SPIR-V assembly |text|. The |text| will be encoded according to the given // target |env|. Returns nullptr if errors occur and sends the errors to // |consumer|. -std::unique_ptr BuildModule( +std::unique_ptr BuildModule( spv_target_env env, MessageConsumer consumer, const std::string& text, uint32_t assemble_options = SpirvTools::kDefaultAssembleOption); } // namespace spvtools -#endif // SPIRV_TOOLS_OPT_BUILD_MODULE_H_ +#endif // SOURCE_OPT_BUILD_MODULE_H_ diff --git a/3rdparty/spirv-tools/source/opt/ccp_pass.cpp b/3rdparty/spirv-tools/source/opt/ccp_pass.cpp index e9044ad41..a8411d9fe 100644 --- a/3rdparty/spirv-tools/source/opt/ccp_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/ccp_pass.cpp @@ -16,13 +16,16 @@ // // Constant propagation with conditional branches, // Wegman and Zadeck, ACM TOPLAS 13(2):181-210. -#include "ccp_pass.h" -#include "fold.h" -#include "function.h" -#include "module.h" -#include "propagator.h" + +#include "source/opt/ccp_pass.h" #include +#include + +#include "source/opt/fold.h" +#include "source/opt/function.h" +#include "source/opt/module.h" +#include "source/opt/propagator.h" namespace spvtools { namespace opt { @@ -38,15 +41,14 @@ const uint32_t kVaryingSSAId = std::numeric_limits::max(); bool CCPPass::IsVaryingValue(uint32_t id) const { return id == kVaryingSSAId; } -SSAPropagator::PropStatus CCPPass::MarkInstructionVarying( - ir::Instruction* instr) { +SSAPropagator::PropStatus CCPPass::MarkInstructionVarying(Instruction* instr) { assert(instr->result_id() != 0 && "Instructions with no result cannot be marked varying."); values_[instr->result_id()] = kVaryingSSAId; return SSAPropagator::kVarying; } -SSAPropagator::PropStatus CCPPass::VisitPhi(ir::Instruction* phi) { +SSAPropagator::PropStatus CCPPass::VisitPhi(Instruction* phi) { uint32_t meet_val_id = 0; // Implement the lattice meet operation. The result of this Phi instruction is @@ -100,7 +102,7 @@ SSAPropagator::PropStatus CCPPass::VisitPhi(ir::Instruction* phi) { return SSAPropagator::kInteresting; } -SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) { +SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) { assert(instr->result_id() != 0 && "Expecting an instruction that produces a result"); @@ -133,8 +135,9 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) { } return it->second; }; - ir::Instruction* folded_inst = - opt::FoldInstructionToConstant(instr, map_func); + Instruction* folded_inst = + context()->get_instruction_folder().FoldInstructionToConstant(instr, + map_func); if (folded_inst != nullptr) { // We do not want to change the body of the function by adding new // instructions. When folding we can only generate new constants. @@ -167,8 +170,8 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) { return MarkInstructionVarying(instr); } -SSAPropagator::PropStatus CCPPass::VisitBranch(ir::Instruction* instr, - ir::BasicBlock** dest_bb) const { +SSAPropagator::PropStatus CCPPass::VisitBranch(Instruction* instr, + BasicBlock** dest_bb) const { assert(instr->IsBranch() && "Expected a branch instruction."); *dest_bb = nullptr; @@ -249,8 +252,8 @@ SSAPropagator::PropStatus CCPPass::VisitBranch(ir::Instruction* instr, return SSAPropagator::kInteresting; } -SSAPropagator::PropStatus CCPPass::VisitInstruction(ir::Instruction* instr, - ir::BasicBlock** dest_bb) { +SSAPropagator::PropStatus CCPPass::VisitInstruction(Instruction* instr, + BasicBlock** dest_bb) { *dest_bb = nullptr; if (instr->opcode() == SpvOpPhi) { return VisitPhi(instr); @@ -274,14 +277,13 @@ bool CCPPass::ReplaceValues() { return retval; } -bool CCPPass::PropagateConstants(ir::Function* fp) { +bool CCPPass::PropagateConstants(Function* fp) { // Mark function parameters as varying. - fp->ForEachParam([this](const ir::Instruction* inst) { + fp->ForEachParam([this](const Instruction* inst) { values_[inst->result_id()] = kVaryingSSAId; }); - const auto visit_fn = [this](ir::Instruction* instr, - ir::BasicBlock** dest_bb) { + const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) { return VisitInstruction(instr, dest_bb); }; @@ -295,9 +297,7 @@ bool CCPPass::PropagateConstants(ir::Function* fp) { return false; } -void CCPPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); - +void CCPPass::Initialize() { const_mgr_ = context()->get_constant_mgr(); // Populate the constant table with values from constant declarations in the @@ -314,13 +314,11 @@ void CCPPass::Initialize(ir::IRContext* c) { } } -Pass::Status CCPPass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status CCPPass::Process() { + Initialize(); // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { - return PropagateConstants(fp); - }; + ProcessFunction pfn = [this](Function* fp) { return PropagateConstants(fp); }; bool modified = ProcessReachableCallTree(pfn, context()); return modified ? Pass::Status::SuccessWithChange : Pass::Status::SuccessWithoutChange; diff --git a/3rdparty/spirv-tools/source/opt/ccp_pass.h b/3rdparty/spirv-tools/source/opt/ccp_pass.h index e96e731a7..178fd1281 100644 --- a/3rdparty/spirv-tools/source/opt/ccp_pass.h +++ b/3rdparty/spirv-tools/source/opt/ccp_pass.h @@ -12,15 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_CCP_PASS_H_ -#define LIBSPIRV_OPT_CCP_PASS_H_ +#ifndef SOURCE_OPT_CCP_PASS_H_ +#define SOURCE_OPT_CCP_PASS_H_ -#include "constants.h" -#include "function.h" -#include "ir_context.h" -#include "mem_pass.h" -#include "module.h" -#include "propagator.h" +#include +#include + +#include "source/opt/constants.h" +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" +#include "source/opt/propagator.h" namespace spvtools { namespace opt { @@ -28,39 +31,48 @@ namespace opt { class CCPPass : public MemPass { public: CCPPass() = default; + const char* name() const override { return "ccp"; } - Status Process(ir::IRContext* c) override; + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; + } private: // Initializes the pass. - void Initialize(ir::IRContext* c); + void Initialize(); // Runs constant propagation on the given function |fp|. Returns true if any // constants were propagated and the IR modified. - bool PropagateConstants(ir::Function* fp); + bool PropagateConstants(Function* fp); // Visits a single instruction |instr|. If the instruction is a conditional // branch that always jumps to the same basic block, it sets the destination // block in |dest_bb|. - SSAPropagator::PropStatus VisitInstruction(ir::Instruction* instr, - ir::BasicBlock** dest_bb); + SSAPropagator::PropStatus VisitInstruction(Instruction* instr, + BasicBlock** dest_bb); // Visits an OpPhi instruction |phi|. This applies the meet operator for the // CCP lattice. Essentially, if all the operands in |phi| have the same // constant value C, the result for |phi| gets assigned the value C. - SSAPropagator::PropStatus VisitPhi(ir::Instruction* phi); + SSAPropagator::PropStatus VisitPhi(Instruction* phi); // Visits an SSA assignment instruction |instr|. If the RHS of |instr| folds // into a constant value C, then the LHS of |instr| is assigned the value C in // |values_|. - SSAPropagator::PropStatus VisitAssignment(ir::Instruction* instr); + SSAPropagator::PropStatus VisitAssignment(Instruction* instr); // Visits a branch instruction |instr|. If the branch is conditional // (OpBranchConditional or OpSwitch), and the value of its selector is known, // |dest_bb| will be set to the corresponding destination block. Unconditional // branches always set |dest_bb| to the single destination block. - SSAPropagator::PropStatus VisitBranch(ir::Instruction* instr, - ir::BasicBlock** dest_bb) const; + SSAPropagator::PropStatus VisitBranch(Instruction* instr, + BasicBlock** dest_bb) const; // Replaces all operands used in |fp| with the corresponding constant values // in |values_|. Returns true if any operands were replaced, and false @@ -69,7 +81,7 @@ class CCPPass : public MemPass { // Marks |instr| as varying by registering a varying value for its result // into the |values_| table. Returns SSAPropagator::kVarying. - SSAPropagator::PropStatus MarkInstructionVarying(ir::Instruction* instr); + SSAPropagator::PropStatus MarkInstructionVarying(Instruction* instr); // Returns true if |id| is the special SSA id that corresponds to a varying // value. @@ -97,4 +109,4 @@ class CCPPass : public MemPass { } // namespace opt } // namespace spvtools -#endif +#endif // SOURCE_OPT_CCP_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/cfg.cpp b/3rdparty/spirv-tools/source/opt/cfg.cpp index 6767570a2..dcf2b573f 100644 --- a/3rdparty/spirv-tools/source/opt/cfg.cpp +++ b/3rdparty/spirv-tools/source/opt/cfg.cpp @@ -12,28 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "cfg.h" -#include "cfa.h" -#include "ir_builder.h" -#include "ir_context.h" -#include "module.h" +#include "source/opt/cfg.h" + +#include +#include + +#include "source/cfa.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" namespace spvtools { -namespace ir { - +namespace opt { namespace { +using cbb_ptr = const opt::BasicBlock*; + // Universal Limit of ResultID + 1 -const int kInvalidId = 0x400000; +const int kMaxResultId = 0x400000; } // namespace -CFG::CFG(ir::Module* module) +CFG::CFG(Module* module) : module_(module), - pseudo_entry_block_(std::unique_ptr( - new ir::Instruction(module->context(), SpvOpLabel, 0, 0, {}))), - pseudo_exit_block_(std::unique_ptr(new ir::Instruction( - module->context(), SpvOpLabel, 0, kInvalidId, {}))) { + pseudo_entry_block_(std::unique_ptr( + new Instruction(module->context(), SpvOpLabel, 0, 0, {}))), + pseudo_exit_block_(std::unique_ptr(new Instruction( + module->context(), SpvOpLabel, 0, kMaxResultId, {}))) { for (auto& fn : *module) { for (auto& blk : fn) { RegisterBlock(&blk); @@ -41,7 +46,7 @@ CFG::CFG(ir::Module* module) } } -void CFG::AddEdges(ir::BasicBlock* blk) { +void CFG::AddEdges(BasicBlock* blk) { uint32_t blk_id = blk->id(); // Force the creation of an entry, not all basic block have predecessors // (such as the entry blocks and some unreachables). @@ -54,7 +59,7 @@ void CFG::AddEdges(ir::BasicBlock* blk) { void CFG::RemoveNonExistingEdges(uint32_t blk_id) { std::vector updated_pred_list; for (uint32_t id : preds(blk_id)) { - const ir::BasicBlock* pred_blk = block(id); + const BasicBlock* pred_blk = block(id); bool has_branch = false; pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) { if (succ == blk_id) { @@ -67,8 +72,8 @@ void CFG::RemoveNonExistingEdges(uint32_t blk_id) { label2preds_.at(blk_id) = std::move(updated_pred_list); } -void CFG::ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root, - std::list* order) { +void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root, + std::list* order) { assert(module_->context()->get_feature_mgr()->HasCapability( SpvCapabilityShader) && "This only works on structured control flow"); @@ -77,17 +82,30 @@ void CFG::ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root, ComputeStructuredSuccessors(func); auto ignore_block = [](cbb_ptr) {}; auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; - auto get_structured_successors = [this](const ir::BasicBlock* b) { + auto get_structured_successors = [this](const BasicBlock* b) { return &(block2structured_succs_[b]); }; // TODO(greg-lunarg): Get rid of const_cast by making moving const // out of the cfa.h prototypes and into the invoking code. auto post_order = [&](cbb_ptr b) { - order->push_front(const_cast(b)); + order->push_front(const_cast(b)); }; - spvtools::CFA::DepthFirstTraversal( - root, get_structured_successors, ignore_block, post_order, ignore_edge); + CFA::DepthFirstTraversal(root, get_structured_successors, + ignore_block, post_order, ignore_edge); +} + +void CFG::ForEachBlockInPostOrder(BasicBlock* bb, + const std::function& f) { + std::vector po; + std::unordered_set seen; + ComputePostOrderTraversal(bb, &po, &seen); + + for (BasicBlock* current_bb : po) { + if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) { + f(current_bb); + } + } } void CFG::ForEachBlockInReversePostOrder( @@ -103,7 +121,7 @@ void CFG::ForEachBlockInReversePostOrder( } } -void CFG::ComputeStructuredSuccessors(ir::Function* func) { +void CFG::ComputeStructuredSuccessors(Function* func) { block2structured_succs_.clear(); for (auto& blk : *func) { // If no predecessors in function, make successor to pseudo entry. @@ -129,8 +147,9 @@ void CFG::ComputeStructuredSuccessors(ir::Function* func) { } } -void CFG::ComputePostOrderTraversal(BasicBlock* bb, vector* order, - unordered_set* seen) { +void CFG::ComputePostOrderTraversal(BasicBlock* bb, + std::vector* order, + std::unordered_set* seen) { seen->insert(bb); static_cast(bb)->ForEachSuccessorLabel( [&order, &seen, this](const uint32_t sbid) { @@ -142,11 +161,11 @@ void CFG::ComputePostOrderTraversal(BasicBlock* bb, vector* order, order->push_back(bb); } -BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { +BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) { assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop."); Function* fn = bb->GetParent(); - IRContext* context = fn->context(); + IRContext* context = module_->context(); // Find the insertion point for the new bb. Function::iterator header_it = std::find_if( @@ -156,7 +175,7 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { const std::vector& pred = preds(bb->id()); // Find the back edge - ir::BasicBlock* latch_block = nullptr; + BasicBlock* latch_block = nullptr; Function::iterator latch_block_iter = header_it; while (++latch_block_iter != fn->end()) { // If blocks are in the proper order, then the only branch that appears @@ -178,13 +197,13 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { ++iter; } - std::unique_ptr newBlock( + std::unique_ptr newBlock( bb->SplitBasicBlock(context, context->TakeNextId(), iter)); // Insert the new bb in the correct position auto insert_pos = header_it; ++insert_pos; - ir::BasicBlock* new_header = &*insert_pos.InsertBefore(std::move(newBlock)); + BasicBlock* new_header = &*insert_pos.InsertBefore(std::move(newBlock)); new_header->SetParent(fn); uint32_t new_header_id = new_header->id(); context->AnalyzeDefUse(new_header->GetLabelInst()); @@ -194,7 +213,7 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { // Update bb mappings. context->set_instr_block(new_header->GetLabelInst(), new_header); - new_header->ForEachInst([new_header, context](ir::Instruction* inst) { + new_header->ForEachInst([new_header, context](Instruction* inst) { context->set_instr_block(inst, new_header); }); @@ -220,13 +239,11 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { // Create a phi instruction if and only if the preheader_phi_ops has more // than one pair. if (preheader_phi_ops.size() > 2) { - opt::InstructionBuilder builder( + InstructionBuilder builder( context, &*bb->begin(), - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* new_phi = - builder.AddPhi(phi->type_id(), preheader_phi_ops); + Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops); // Add the OpPhi to the header bb. header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}}); @@ -239,7 +256,7 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { } phi->RemoveFromList(); - std::unique_ptr phi_owner(phi); + std::unique_ptr phi_owner(phi); phi->SetInOperands(std::move(header_phi_ops)); new_header->begin()->InsertBefore(std::move(phi_owner)); context->set_instr_block(phi, new_header); @@ -247,14 +264,13 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { }); // Add a branch to the new header. - opt::InstructionBuilder branch_builder( + InstructionBuilder branch_builder( context, bb, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); - bb->AddInstruction(MakeUnique( - context, SpvOpBranch, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {new_header->id()}}})); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + bb->AddInstruction( + MakeUnique(context, SpvOpBranch, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {new_header->id()}}})); context->AnalyzeUses(bb->terminator()); context->set_instr_block(bb->terminator(), bb); label2preds_[new_header->id()].push_back(bb->id()); @@ -265,7 +281,7 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { *id = new_header_id; } }); - ir::Instruction* latch_branch = latch_block->terminator(); + Instruction* latch_branch = latch_block->terminator(); context->AnalyzeUses(latch_branch); label2preds_[new_header->id()].push_back(latch_block->id()); @@ -276,7 +292,7 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { block_preds.erase(latch_pos); // Update the loop descriptors - if (context->AreAnalysesValid(ir::IRContext::kAnalysisLoopAnalysis)) { + if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) { LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent()); Loop* loop = (*loop_desc)[bb->id()]; @@ -298,13 +314,5 @@ BasicBlock* CFG::SplitLoopHeader(ir::BasicBlock* bb) { return new_header; } -unordered_set CFG::FindReachableBlocks(BasicBlock* start) { - std::unordered_set reachable_blocks; - ForEachBlockInReversePostOrder(start, [&reachable_blocks](BasicBlock* bb) { - reachable_blocks.insert(bb); - }); - return reachable_blocks; -} - -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/cfg.h b/3rdparty/spirv-tools/source/opt/cfg.h index ffff7e176..375d09c5c 100644 --- a/3rdparty/spirv-tools/source/opt/cfg.h +++ b/3rdparty/spirv-tools/source/opt/cfg.h @@ -12,28 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_CFG_H_ -#define LIBSPIRV_OPT_CFG_H_ - -#include "basic_block.h" +#ifndef SOURCE_OPT_CFG_H_ +#define SOURCE_OPT_CFG_H_ #include #include #include #include +#include + +#include "source/opt/basic_block.h" namespace spvtools { -namespace ir { +namespace opt { class CFG { public: - CFG(ir::Module* module); - - // Return the module described by this CFG. - ir::Module* get_module() const { return module_; } + explicit CFG(Module* module); // Return the list of predecesors for basic block with label |blkid|. - // TODO(dnovillo): Move this to ir::BasicBlock. + // TODO(dnovillo): Move this to BasicBlock. const std::vector& preds(uint32_t blk_id) const { assert(label2preds_.count(blk_id)); return label2preds_.at(blk_id); @@ -41,26 +39,22 @@ class CFG { // Return a pointer to the basic block instance corresponding to the label // |blk_id|. - ir::BasicBlock* block(uint32_t blk_id) const { return id2block_.at(blk_id); } + BasicBlock* block(uint32_t blk_id) const { return id2block_.at(blk_id); } // Return the pseudo entry and exit blocks. - const ir::BasicBlock* pseudo_entry_block() const { - return &pseudo_entry_block_; - } - ir::BasicBlock* pseudo_entry_block() { return &pseudo_entry_block_; } + const BasicBlock* pseudo_entry_block() const { return &pseudo_entry_block_; } + BasicBlock* pseudo_entry_block() { return &pseudo_entry_block_; } - const ir::BasicBlock* pseudo_exit_block() const { - return &pseudo_exit_block_; - } - ir::BasicBlock* pseudo_exit_block() { return &pseudo_exit_block_; } + const BasicBlock* pseudo_exit_block() const { return &pseudo_exit_block_; } + BasicBlock* pseudo_exit_block() { return &pseudo_exit_block_; } // Return true if |block_ptr| is the pseudo-entry block. - bool IsPseudoEntryBlock(ir::BasicBlock* block_ptr) const { + bool IsPseudoEntryBlock(BasicBlock* block_ptr) const { return block_ptr == &pseudo_entry_block_; } // Return true if |block_ptr| is the pseudo-exit block. - bool IsPseudoExitBlock(ir::BasicBlock* block_ptr) const { + bool IsPseudoExitBlock(BasicBlock* block_ptr) const { return block_ptr == &pseudo_exit_block_; } @@ -68,8 +62,14 @@ class CFG { // This order has the property that dominators come before all blocks they // dominate and merge blocks come after all blocks that are in the control // constructs of their header. - void ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root, - std::list* order); + void ComputeStructuredOrder(Function* func, BasicBlock* root, + std::list* order); + + // Applies |f| to the basic block in post order starting with |bb|. + // Note that basic blocks that cannot be reached from |bb| node will not be + // processed. + void ForEachBlockInPostOrder(BasicBlock* bb, + const std::function& f); // Applies |f| to the basic block in reverse post order starting with |bb|. // Note that basic blocks that cannot be reached from |bb| node will not be @@ -79,14 +79,14 @@ class CFG { // Registers |blk| as a basic block in the cfg, this also updates the // predecessor lists of each successor of |blk|. - void RegisterBlock(ir::BasicBlock* blk) { + void RegisterBlock(BasicBlock* blk) { uint32_t blk_id = blk->id(); id2block_[blk_id] = blk; AddEdges(blk); } // Removes from the CFG any mapping for the basic block id |blk_id|. - void ForgetBlock(const ir::BasicBlock* blk) { + void ForgetBlock(const BasicBlock* blk) { id2block_.erase(blk->id()); label2preds_.erase(blk->id()); RemoveSuccessorEdges(blk); @@ -101,7 +101,7 @@ class CFG { } // Registers |blk| to all of its successors. - void AddEdges(ir::BasicBlock* blk); + void AddEdges(BasicBlock* blk); // Registers the basic block id |pred_blk_id| as being a predecessor of the // basic block id |succ_blk_id|. @@ -114,7 +114,7 @@ class CFG { void RemoveNonExistingEdges(uint32_t blk_id); // Remove all edges that leave |bb|. - void RemoveSuccessorEdges(const ir::BasicBlock* bb) { + void RemoveSuccessorEdges(const BasicBlock* bb) { bb->ForEachSuccessorLabel( [bb, this](uint32_t succ_id) { RemoveEdge(bb->id(), succ_id); }); } @@ -124,13 +124,9 @@ class CFG { // is a new block that will be the new loop header. // // Returns a pointer to the new loop header. - BasicBlock* SplitLoopHeader(ir::BasicBlock* bb); - - std::unordered_set FindReachableBlocks(BasicBlock* start); + BasicBlock* SplitLoopHeader(BasicBlock* bb); private: - using cbb_ptr = const ir::BasicBlock*; - // Compute structured successors for function |func|. A block's structured // successors are the blocks it branches to together with its declared merge // block and continue block if it has them. When order matters, the merge @@ -138,7 +134,7 @@ class CFG { // first search in the presence of early returns and kills. If the successor // vector contain duplicates of the merge or continue blocks, they are safely // ignored by DFS. - void ComputeStructuredSuccessors(ir::Function* func); + void ComputeStructuredSuccessors(Function* func); // Computes the post-order traversal of the cfg starting at |bb| skipping // nodes in |seen|. The order of the traversal is appended to |order|, and @@ -148,28 +144,28 @@ class CFG { std::unordered_set* seen); // Module for this CFG. - ir::Module* module_; + Module* module_; // Map from block to its structured successor blocks. See // ComputeStructuredSuccessors() for definition. - std::unordered_map> + std::unordered_map> block2structured_succs_; // Extra block whose successors are all blocks with no predecessors // in function. - ir::BasicBlock pseudo_entry_block_; + BasicBlock pseudo_entry_block_; // Augmented CFG Exit Block. - ir::BasicBlock pseudo_exit_block_; + BasicBlock pseudo_exit_block_; // Map from block's label id to its predecessor blocks ids std::unordered_map> label2preds_; // Map from block's label id to block. - std::unordered_map id2block_; + std::unordered_map id2block_; }; -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_CFG_H_ +#endif // SOURCE_OPT_CFG_H_ diff --git a/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp b/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp index 2565225eb..2d548462b 100644 --- a/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp @@ -19,21 +19,17 @@ #include #include -#include "cfg_cleanup_pass.h" +#include "source/opt/cfg_cleanup_pass.h" -#include "function.h" -#include "module.h" +#include "source/opt/function.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { -void CFGCleanupPass::Initialize(ir::IRContext* c) { InitializeProcessing(c); } - -Pass::Status CFGCleanupPass::Process(ir::IRContext* c) { - Initialize(c); - +Pass::Status CFGCleanupPass::Process() { // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { return CFGCleanup(fp); }; + ProcessFunction pfn = [this](Function* fp) { return CFGCleanup(fp); }; bool modified = ProcessReachableCallTree(pfn, context()); return modified ? Pass::Status::SuccessWithChange : Pass::Status::SuccessWithoutChange; diff --git a/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.h b/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.h index 116e11d1b..afbc67c09 100644 --- a/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.h +++ b/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_CFG_CLEANUP_PASS_H_ -#define LIBSPIRV_OPT_CFG_CLEANUP_PASS_H_ +#ifndef SOURCE_OPT_CFG_CLEANUP_PASS_H_ +#define SOURCE_OPT_CFG_CLEANUP_PASS_H_ -#include "function.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/function.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -25,19 +25,16 @@ namespace opt { class CFGCleanupPass : public MemPass { public: CFGCleanupPass() = default; + const char* name() const override { return "cfg-cleanup"; } - Status Process(ir::IRContext* c) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse; } - - private: - // Initialize the pass. - void Initialize(ir::IRContext* c); }; } // namespace opt } // namespace spvtools -#endif +#endif // SOURCE_OPT_CFG_CLEANUP_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp b/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp new file mode 100644 index 000000000..facfc24b6 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp @@ -0,0 +1,290 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/combine_access_chains.h" + +#include + +#include "source/opt/constants.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +Pass::Status CombineAccessChains::Process() { + bool modified = false; + + for (auto& function : *get_module()) { + modified |= ProcessFunction(function); + } + + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool CombineAccessChains::ProcessFunction(Function& function) { + bool modified = false; + + cfg()->ForEachBlockInReversePostOrder( + function.entry().get(), [&modified, this](BasicBlock* block) { + block->ForEachInst([&modified, this](Instruction* inst) { + switch (inst->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + modified |= CombineAccessChain(inst); + break; + default: + break; + } + }); + }); + + return modified; +} + +uint32_t CombineAccessChains::GetConstantValue( + const analysis::Constant* constant_inst) { + if (constant_inst->type()->AsInteger()->width() <= 32) { + if (constant_inst->type()->AsInteger()->IsSigned()) { + return static_cast(constant_inst->GetS32()); + } else { + return constant_inst->GetU32(); + } + } else { + assert(false); + return 0u; + } +} + +uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { + uint32_t array_stride = 0; + context()->get_decoration_mgr()->WhileEachDecoration( + inst->type_id(), SpvDecorationArrayStride, + [&array_stride](const Instruction& decoration) { + assert(decoration.opcode() != SpvOpDecorateId); + if (decoration.opcode() == SpvOpDecorate) { + array_stride = decoration.GetSingleWordInOperand(1); + } else { + array_stride = decoration.GetSingleWordInOperand(2); + } + return false; + }); + return array_stride; +} + +const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + + Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); + assert(type->AsPointer()); + type = type->AsPointer()->pointee_type(); + std::vector element_indices; + uint32_t starting_index = 1; + if (IsPtrAccessChain(inst->opcode())) { + // Skip the first index of OpPtrAccessChain as it does not affect type + // resolution. + starting_index = 2; + } + for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { + Instruction* index_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); + const analysis::Constant* index_constant = + context()->get_constant_mgr()->GetConstantFromInst(index_inst); + if (index_constant) { + uint32_t index_value = GetConstantValue(index_constant); + element_indices.push_back(index_value); + } else { + // This index must not matter to resolve the type in valid SPIR-V. + element_indices.push_back(0); + } + } + type = type_mgr->GetMemberType(type, element_indices); + return type; +} + +bool CombineAccessChains::CombineIndices(Instruction* ptr_input, + Instruction* inst, + std::vector* new_operands) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); + + Instruction* last_index_inst = def_use_mgr->GetDef( + ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); + const analysis::Constant* last_index_constant = + constant_mgr->GetConstantFromInst(last_index_inst); + + Instruction* element_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); + const analysis::Constant* element_constant = + constant_mgr->GetConstantFromInst(element_inst); + + // Combine the last index of the AccessChain (|ptr_inst|) with the element + // operand of the PtrAccessChain (|inst|). + const bool combining_element_operands = + IsPtrAccessChain(inst->opcode()) && + IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; + uint32_t new_value_id = 0; + const analysis::Type* type = GetIndexedType(ptr_input); + if (last_index_constant && element_constant) { + // Combine the constants. + uint32_t new_value = GetConstantValue(last_index_constant) + + GetConstantValue(element_constant); + const analysis::Constant* new_value_constant = + constant_mgr->GetConstant(last_index_constant->type(), {new_value}); + Instruction* new_value_inst = + constant_mgr->GetDefiningInstruction(new_value_constant); + new_value_id = new_value_inst->result_id(); + } else if (!type->AsStruct() || combining_element_operands) { + // Generate an addition of the two indices. + InstructionBuilder builder( + context(), inst, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), + last_index_inst->result_id(), + element_inst->result_id()); + new_value_id = addition->result_id(); + } else { + // Indexing into structs must be constant, so bail out here. + return false; + } + new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); + return true; +} + +bool CombineAccessChains::CreateNewInputOperands( + Instruction* ptr_input, Instruction* inst, + std::vector* new_operands) { + // Start by copying all the input operands of the feeder access chain. + for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { + new_operands->push_back(ptr_input->GetInOperand(i)); + } + + // Deal with the last index of the feeder access chain. + if (IsPtrAccessChain(inst->opcode())) { + // The last index of the feeder should be combined with the element operand + // of |inst|. + if (!CombineIndices(ptr_input, inst, new_operands)) return false; + } else { + // The indices aren't being combined so now add the last index operand of + // |ptr_input|. + new_operands->push_back( + ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); + } + + // Copy the remaining index operands. + uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; + for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { + new_operands->push_back(inst->GetInOperand(i)); + } + + return true; +} + +bool CombineAccessChains::CombineAccessChain(Instruction* inst) { + assert((inst->opcode() == SpvOpPtrAccessChain || + inst->opcode() == SpvOpAccessChain || + inst->opcode() == SpvOpInBoundsAccessChain || + inst->opcode() == SpvOpInBoundsPtrAccessChain) && + "Wrong opcode. Expected an access chain."); + + Instruction* ptr_input = + context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); + if (ptr_input->opcode() != SpvOpAccessChain && + ptr_input->opcode() != SpvOpInBoundsAccessChain && + ptr_input->opcode() != SpvOpPtrAccessChain && + ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) { + return false; + } + + if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; + + // Handles the following cases: + // 1. |ptr_input| is an index-less access chain. Replace the pointer + // in |inst| with |ptr_input|'s pointer. + // 2. |inst| is a index-less access chain. Change |inst| to an + // OpCopyObject. + // 3. |inst| is not a pointer access chain. + // |inst|'s indices are appended to |ptr_input|'s indices. + // 4. |ptr_input| is not pointer access chain. + // |inst| is a pointer access chain. + // |inst|'s element operand is combined with the last index in + // |ptr_input| to form a new operand. + // 5. |ptr_input| is a pointer access chain. + // Like the above scenario, |inst|'s element operand is combined + // with |ptr_input|'s last index. This results is either a + // combined element operand or combined regular index. + + // TODO(alan-baker): Support this properly. Requires analyzing the + // size/alignment of the type and converting the stride into an element + // index. + uint32_t array_stride = GetArrayStride(ptr_input); + if (array_stride != 0) return false; + + if (ptr_input->NumInOperands() == 1) { + // The input is effectively a no-op. + inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)}); + context()->AnalyzeUses(inst); + } else if (inst->NumInOperands() == 1) { + // |inst| is a no-op, change it to a copy. Instruction simplification will + // clean it up. + inst->SetOpcode(SpvOpCopyObject); + } else { + std::vector new_operands; + if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; + + // Update the instruction. + inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); + inst->SetInOperands(std::move(new_operands)); + context()->AnalyzeUses(inst); + } + return true; +} + +SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { + auto IsInBounds = [](SpvOp opcode) { + return opcode == SpvOpInBoundsPtrAccessChain || + opcode == SpvOpInBoundsAccessChain; + }; + + if (input_opcode == SpvOpInBoundsPtrAccessChain) { + if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; + } else if (input_opcode == SpvOpInBoundsAccessChain) { + if (!IsInBounds(base_opcode)) return SpvOpAccessChain; + } + + return input_opcode; +} + +bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { + return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; +} + +bool CombineAccessChains::Has64BitIndices(Instruction* inst) { + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + Instruction* index_inst = + context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); + const analysis::Type* index_type = + context()->get_type_mgr()->GetType(index_inst->type_id()); + if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) + return true; + } + return false; +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/combine_access_chains.h b/3rdparty/spirv-tools/source/opt/combine_access_chains.h new file mode 100644 index 000000000..75885dada --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/combine_access_chains.h @@ -0,0 +1,82 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_COMBINE_ACCESS_CHAINS_H_ +#define SOURCE_OPT_COMBINE_ACCESS_CHAINS_H_ + +#include + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class CombineAccessChains : public Pass { + public: + const char* name() const override { return "combine-access-chains"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; + } + + private: + // Combine access chains in |function|. Blocks are processed in reverse + // post-order. Returns true if the function is modified. + bool ProcessFunction(Function& function); + + // Combines an access chain (normal, in bounds or pointer) |inst| if its base + // pointer is another access chain. Returns true if the access chain was + // modified. + bool CombineAccessChain(Instruction* inst); + + // Returns the value of |constant_inst| as a uint32_t. + uint32_t GetConstantValue(const analysis::Constant* constant_inst); + + // Returns the array stride of |inst|'s type. + uint32_t GetArrayStride(const Instruction* inst); + + // Returns the type by resolving the index operands |inst|. |inst| must be an + // access chain instruction. + const analysis::Type* GetIndexedType(Instruction* inst); + + // Populates |new_operands| with the operands for the combined access chain. + // Returns false if the access chains cannot be combined. + bool CreateNewInputOperands(Instruction* ptr_input, Instruction* inst, + std::vector* new_operands); + + // Combines the last index of |ptr_input| with the element operand of |inst|. + // Adds the combined operand to |new_operands|. + bool CombineIndices(Instruction* ptr_input, Instruction* inst, + std::vector* new_operands); + + // Returns the opcode to use for the combined access chain. + SpvOp UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode); + + // Returns true if |opcode| is a pointer access chain. + bool IsPtrAccessChain(SpvOp opcode); + + // Returns true if |inst| (an access chain) has 64-bit indices. + bool Has64BitIndices(Instruction* inst); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_COMBINE_ACCESS_CHAINS_H_ diff --git a/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.cpp index b060ea9c5..e6426a555 100644 --- a/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.cpp @@ -14,9 +14,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "common_uniform_elim_pass.h" -#include "cfa.h" -#include "ir_context.h" +#include "source/opt/common_uniform_elim_pass.h" +#include "source/cfa.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { @@ -41,7 +41,7 @@ bool CommonUniformElimPass::IsNonPtrAccessChain(const SpvOp opcode) const { } bool CommonUniformElimPass::IsSamplerOrImageType( - const ir::Instruction* typeInst) const { + const Instruction* typeInst) const { switch (typeInst->opcode()) { case SpvOpTypeSampler: case SpvOpTypeImage: @@ -53,7 +53,7 @@ bool CommonUniformElimPass::IsSamplerOrImageType( if (typeInst->opcode() != SpvOpTypeStruct) return false; // Return true if any member is a sampler or image return !typeInst->WhileEachInId([this](const uint32_t* tid) { - const ir::Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); + const Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); if (IsSamplerOrImageType(compTypeInst)) { return false; } @@ -62,28 +62,27 @@ bool CommonUniformElimPass::IsSamplerOrImageType( } bool CommonUniformElimPass::IsSamplerOrImageVar(uint32_t varId) const { - const ir::Instruction* varInst = get_def_use_mgr()->GetDef(varId); + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); assert(varInst->opcode() == SpvOpVariable); const uint32_t varTypeId = varInst->type_id(); - const ir::Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); const uint32_t varPteTypeId = varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); - ir::Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); + Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); return IsSamplerOrImageType(varPteTypeInst); } -ir::Instruction* CommonUniformElimPass::GetPtr(ir::Instruction* ip, - uint32_t* objId) { +Instruction* CommonUniformElimPass::GetPtr(Instruction* ip, uint32_t* objId) { const SpvOp op = ip->opcode(); assert(op == SpvOpStore || op == SpvOpLoad); *objId = ip->GetSingleWordInOperand(op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx); - ir::Instruction* ptrInst = get_def_use_mgr()->GetDef(*objId); + Instruction* ptrInst = get_def_use_mgr()->GetDef(*objId); while (ptrInst->opcode() == SpvOpCopyObject) { *objId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx); ptrInst = get_def_use_mgr()->GetDef(*objId); } - ir::Instruction* objInst = ptrInst; + Instruction* objInst = ptrInst; while (objInst->opcode() != SpvOpVariable && objInst->opcode() != SpvOpFunctionParameter) { if (IsNonPtrAccessChain(objInst->opcode())) { @@ -100,22 +99,21 @@ ir::Instruction* CommonUniformElimPass::GetPtr(ir::Instruction* ip, bool CommonUniformElimPass::IsVolatileStruct(uint32_t type_id) { assert(get_def_use_mgr()->GetDef(type_id)->opcode() == SpvOpTypeStruct); return !get_decoration_mgr()->WhileEachDecoration( - type_id, SpvDecorationVolatile, - [](const ir::Instruction&) { return false; }); + type_id, SpvDecorationVolatile, [](const Instruction&) { return false; }); } bool CommonUniformElimPass::IsAccessChainToVolatileStructType( - const ir::Instruction& AccessChainInst) { + const Instruction& AccessChainInst) { assert(AccessChainInst.opcode() == SpvOpAccessChain); uint32_t ptr_id = AccessChainInst.GetSingleWordInOperand(0); - const ir::Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id); + const Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id); uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst); const uint32_t num_operands = AccessChainInst.NumOperands(); // walk the type tree: for (uint32_t idx = 3; idx < num_operands; ++idx) { - ir::Instruction* pointee_type = get_def_use_mgr()->GetDef(pointee_type_id); + Instruction* pointee_type = get_def_use_mgr()->GetDef(pointee_type_id); switch (pointee_type->opcode()) { case SpvOpTypeMatrix: @@ -130,8 +128,7 @@ bool CommonUniformElimPass::IsAccessChainToVolatileStructType( if (idx < num_operands - 1) { const uint32_t index_id = AccessChainInst.GetSingleWordOperand(idx); - const ir::Instruction* index_inst = - get_def_use_mgr()->GetDef(index_id); + const Instruction* index_inst = get_def_use_mgr()->GetDef(index_id); uint32_t index_value = index_inst->GetSingleWordOperand( 2); // TODO: replace with GetUintValueFromConstant() pointee_type_id = pointee_type->GetSingleWordInOperand(index_value); @@ -144,7 +141,7 @@ bool CommonUniformElimPass::IsAccessChainToVolatileStructType( return false; } -bool CommonUniformElimPass::IsVolatileLoad(const ir::Instruction& loadInst) { +bool CommonUniformElimPass::IsVolatileLoad(const Instruction& loadInst) { assert(loadInst.opcode() == SpvOpLoad); // Check if this Load instruction has Volatile Memory Access flag if (loadInst.NumOperands() == 4) { @@ -161,11 +158,11 @@ bool CommonUniformElimPass::IsVolatileLoad(const ir::Instruction& loadInst) { } bool CommonUniformElimPass::IsUniformVar(uint32_t varId) { - const ir::Instruction* varInst = + const Instruction* varInst = get_def_use_mgr()->id_to_defs().find(varId)->second; if (varInst->opcode() != SpvOpVariable) return false; const uint32_t varTypeId = varInst->type_id(); - const ir::Instruction* varTypeInst = + const Instruction* varTypeInst = get_def_use_mgr()->id_to_defs().find(varTypeId)->second; return varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) == SpvStorageClassUniform || @@ -174,21 +171,21 @@ bool CommonUniformElimPass::IsUniformVar(uint32_t varId) { } bool CommonUniformElimPass::HasUnsupportedDecorates(uint32_t id) const { - return !get_def_use_mgr()->WhileEachUser(id, [this](ir::Instruction* user) { + return !get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { if (IsNonTypeDecorate(user->opcode())) return false; return true; }); } bool CommonUniformElimPass::HasOnlyNamesAndDecorates(uint32_t id) const { - return get_def_use_mgr()->WhileEachUser(id, [this](ir::Instruction* user) { + return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { SpvOp op = user->opcode(); if (op != SpvOpName && !IsNonTypeDecorate(op)) return false; return true; }); } -void CommonUniformElimPass::DeleteIfUseless(ir::Instruction* inst) { +void CommonUniformElimPass::DeleteIfUseless(Instruction* inst) { const uint32_t resId = inst->result_id(); assert(resId != 0); if (HasOnlyNamesAndDecorates(resId)) { @@ -196,34 +193,33 @@ void CommonUniformElimPass::DeleteIfUseless(ir::Instruction* inst) { } } -ir::Instruction* CommonUniformElimPass::ReplaceAndDeleteLoad( - ir::Instruction* loadInst, uint32_t replId, ir::Instruction* ptrInst) { +Instruction* CommonUniformElimPass::ReplaceAndDeleteLoad(Instruction* loadInst, + uint32_t replId, + Instruction* ptrInst) { const uint32_t loadId = loadInst->result_id(); context()->KillNamesAndDecorates(loadId); (void)context()->ReplaceAllUsesWith(loadId, replId); // remove load instruction - ir::Instruction* next_instruction = context()->KillInst(loadInst); + Instruction* next_instruction = context()->KillInst(loadInst); // if access chain, see if it can be removed as well if (IsNonPtrAccessChain(ptrInst->opcode())) DeleteIfUseless(ptrInst); return next_instruction; } void CommonUniformElimPass::GenACLoadRepl( - const ir::Instruction* ptrInst, - std::vector>* newInsts, - uint32_t* resultId) { + const Instruction* ptrInst, + std::vector>* newInsts, uint32_t* resultId) { // Build and append Load const uint32_t ldResultId = TakeNextId(); const uint32_t varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx); - const ir::Instruction* varInst = get_def_use_mgr()->GetDef(varId); + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); assert(varInst->opcode() == SpvOpVariable); const uint32_t varPteTypeId = GetPointeeTypeId(varInst); - std::vector load_in_operands; - load_in_operands.push_back( - ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, - std::initializer_list{varId})); - std::unique_ptr newLoad(new ir::Instruction( + std::vector load_in_operands; + load_in_operands.push_back(Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + std::initializer_list{varId})); + std::unique_ptr newLoad(new Instruction( context(), SpvOpLoad, varPteTypeId, ldResultId, load_in_operands)); get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); newInsts->emplace_back(std::move(newLoad)); @@ -231,34 +227,33 @@ void CommonUniformElimPass::GenACLoadRepl( // Build and append Extract const uint32_t extResultId = TakeNextId(); const uint32_t ptrPteTypeId = GetPointeeTypeId(ptrInst); - std::vector ext_in_opnds; - ext_in_opnds.push_back( - ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, - std::initializer_list{ldResultId})); + std::vector ext_in_opnds; + ext_in_opnds.push_back(Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + std::initializer_list{ldResultId})); uint32_t iidIdx = 0; ptrInst->ForEachInId([&iidIdx, &ext_in_opnds, this](const uint32_t* iid) { if (iidIdx > 0) { - const ir::Instruction* cInst = get_def_use_mgr()->GetDef(*iid); + const Instruction* cInst = get_def_use_mgr()->GetDef(*iid); uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx); ext_in_opnds.push_back( - ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, - std::initializer_list{val})); + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + std::initializer_list{val})); } ++iidIdx; }); - std::unique_ptr newExt( - new ir::Instruction(context(), SpvOpCompositeExtract, ptrPteTypeId, - extResultId, ext_in_opnds)); + std::unique_ptr newExt( + new Instruction(context(), SpvOpCompositeExtract, ptrPteTypeId, + extResultId, ext_in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newExt); newInsts->emplace_back(std::move(newExt)); *resultId = extResultId; } -bool CommonUniformElimPass::IsConstantIndexAccessChain(ir::Instruction* acp) { +bool CommonUniformElimPass::IsConstantIndexAccessChain(Instruction* acp) { uint32_t inIdx = 0; return acp->WhileEachInId([&inIdx, this](uint32_t* tid) { if (inIdx > 0) { - ir::Instruction* opInst = get_def_use_mgr()->GetDef(*tid); + Instruction* opInst = get_def_use_mgr()->GetDef(*tid); if (opInst->opcode() != SpvOpConstant) return false; } ++inIdx; @@ -266,13 +261,13 @@ bool CommonUniformElimPass::IsConstantIndexAccessChain(ir::Instruction* acp) { }); } -bool CommonUniformElimPass::UniformAccessChainConvert(ir::Function* func) { +bool CommonUniformElimPass::UniformAccessChainConvert(Function* func) { bool modified = false; for (auto bi = func->begin(); bi != func->end(); ++bi) { - for (ir::Instruction* inst = &*bi->begin(); inst; inst = inst->NextNode()) { + for (Instruction* inst = &*bi->begin(); inst; inst = inst->NextNode()) { if (inst->opcode() != SpvOpLoad) continue; uint32_t varId; - ir::Instruction* ptrInst = GetPtr(inst, &varId); + Instruction* ptrInst = GetPtr(inst, &varId); if (!IsNonPtrAccessChain(ptrInst->opcode())) continue; // Do not convert nested access chains if (ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) != varId) @@ -283,18 +278,18 @@ bool CommonUniformElimPass::UniformAccessChainConvert(ir::Function* func) { if (HasUnsupportedDecorates(ptrInst->result_id())) continue; if (IsVolatileLoad(*inst)) continue; if (IsAccessChainToVolatileStructType(*ptrInst)) continue; - std::vector> newInsts; + std::vector> newInsts; uint32_t replId; GenACLoadRepl(ptrInst, &newInsts, &replId); inst = ReplaceAndDeleteLoad(inst, replId, ptrInst); inst = inst->InsertBefore(std::move(newInsts)); modified = true; - }; + } } return modified; } -void CommonUniformElimPass::ComputeStructuredSuccessors(ir::Function* func) { +void CommonUniformElimPass::ComputeStructuredSuccessors(Function* func) { block2structured_succs_.clear(); for (auto& blk : *func) { // If header, make merge block first successor. @@ -315,32 +310,32 @@ void CommonUniformElimPass::ComputeStructuredSuccessors(ir::Function* func) { } void CommonUniformElimPass::ComputeStructuredOrder( - ir::Function* func, std::list* order) { + Function* func, std::list* order) { // Compute structured successors and do DFS ComputeStructuredSuccessors(func); auto ignore_block = [](cbb_ptr) {}; auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; - auto get_structured_successors = [this](const ir::BasicBlock* block) { + auto get_structured_successors = [this](const BasicBlock* block) { return &(block2structured_succs_[block]); }; // TODO(greg-lunarg): Get rid of const_cast by making moving const // out of the cfa.h prototypes and into the invoking code. auto post_order = [&](cbb_ptr b) { - order->push_front(const_cast(b)); + order->push_front(const_cast(b)); }; order->clear(); - spvtools::CFA::DepthFirstTraversal( - &*func->begin(), get_structured_successors, ignore_block, post_order, - ignore_edge); + CFA::DepthFirstTraversal(&*func->begin(), + get_structured_successors, ignore_block, + post_order, ignore_edge); } -bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) { +bool CommonUniformElimPass::CommonUniformLoadElimination(Function* func) { // Process all blocks in structured order. This is just one way (the // simplest?) to keep track of the most recent block outside of control // flow, used to copy common instructions, guaranteed to dominate all // following load sites. - std::list structuredOrder; + std::list structuredOrder; ComputeStructuredOrder(func, &structuredOrder); uniform2load_id_.clear(); bool modified = false; @@ -349,19 +344,25 @@ bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) { while (insertItr->opcode() == SpvOpVariable || insertItr->opcode() == SpvOpNop) ++insertItr; + // Update insertItr until it will not be removed. Without this code, + // ReplaceAndDeleteLoad() can set |insertItr| as a dangling pointer. + while (IsUniformLoadToBeRemoved(&*insertItr)) ++insertItr; uint32_t mergeBlockId = 0; for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) { - ir::BasicBlock* bp = *bi; + BasicBlock* bp = *bi; // Check if we are exiting outermost control construct. If so, remember // new load insertion point. Trying to keep register pressure down. if (mergeBlockId == bp->id()) { mergeBlockId = 0; insertItr = bp->begin(); + // Update insertItr until it will not be removed. Without this code, + // ReplaceAndDeleteLoad() can set |insertItr| as a dangling pointer. + while (IsUniformLoadToBeRemoved(&*insertItr)) ++insertItr; } - for (ir::Instruction* inst = &*bp->begin(); inst; inst = inst->NextNode()) { + for (Instruction* inst = &*bp->begin(); inst; inst = inst->NextNode()) { if (inst->opcode() != SpvOpLoad) continue; uint32_t varId; - ir::Instruction* ptrInst = GetPtr(inst, &varId); + Instruction* ptrInst = GetPtr(inst, &varId); if (ptrInst->opcode() != SpvOpVariable) continue; if (!IsUniformVar(varId)) continue; if (IsSamplerOrImageVar(varId)) continue; @@ -379,7 +380,7 @@ bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) { } else { // Copy load into most recent dominating block and remember it replId = TakeNextId(); - std::unique_ptr newLoad(new ir::Instruction( + std::unique_ptr newLoad(new Instruction( context(), SpvOpLoad, inst->type_id(), replId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); @@ -400,14 +401,14 @@ bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) { return modified; } -bool CommonUniformElimPass::CommonUniformLoadElimBlock(ir::Function* func) { +bool CommonUniformElimPass::CommonUniformLoadElimBlock(Function* func) { bool modified = false; for (auto& blk : *func) { uniform2load_id_.clear(); - for (ir::Instruction* inst = &*blk.begin(); inst; inst = inst->NextNode()) { + for (Instruction* inst = &*blk.begin(); inst; inst = inst->NextNode()) { if (inst->opcode() != SpvOpLoad) continue; uint32_t varId; - ir::Instruction* ptrInst = GetPtr(inst, &varId); + Instruction* ptrInst = GetPtr(inst, &varId); if (ptrInst->opcode() != SpvOpVariable) continue; if (!IsUniformVar(varId)) continue; if (!IsSamplerOrImageVar(varId)) continue; @@ -428,7 +429,7 @@ bool CommonUniformElimPass::CommonUniformLoadElimBlock(ir::Function* func) { return modified; } -bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) { +bool CommonUniformElimPass::CommonExtractElimination(Function* func) { // Find all composite ids with duplicate extracts. for (auto bi = func->begin(); bi != func->end(); ++bi) { for (auto ii = bi->begin(); ii != bi->end(); ++ii) { @@ -451,7 +452,7 @@ bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) { for (auto idxItr : cItr->second) { if (idxItr.second.size() < 2) continue; uint32_t replId = TakeNextId(); - std::unique_ptr newExtract( + std::unique_ptr newExtract( idxItr.second.front()->Clone(context())); newExtract->SetResultId(replId); get_def_use_mgr()->AnalyzeInstDefUse(&*newExtract); @@ -470,7 +471,7 @@ bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) { return modified; } -bool CommonUniformElimPass::EliminateCommonUniform(ir::Function* func) { +bool CommonUniformElimPass::EliminateCommonUniform(Function* func) { bool modified = false; modified |= UniformAccessChainConvert(func); modified |= CommonUniformLoadElimination(func); @@ -480,9 +481,7 @@ bool CommonUniformElimPass::EliminateCommonUniform(ir::Function* func) { return modified; } -void CommonUniformElimPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); - +void CommonUniformElimPass::Initialize() { // Clear collections. comp2idx2inst_.clear(); @@ -519,22 +518,22 @@ Pass::Status CommonUniformElimPass::ProcessImpl() { if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; // If non-32-bit integer type in module, terminate processing // TODO(): Handle non-32-bit integer constants in access chains - for (const ir::Instruction& inst : get_module()->types_values()) + for (const Instruction& inst : get_module()->types_values()) if (inst.opcode() == SpvOpTypeInt && inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) return Status::SuccessWithoutChange; // Process entry point functions - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return EliminateCommonUniform(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -CommonUniformElimPass::CommonUniformElimPass() {} +CommonUniformElimPass::CommonUniformElimPass() = default; -Pass::Status CommonUniformElimPass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status CommonUniformElimPass::Process() { + Initialize(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.h b/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.h index 641520892..e6ef69c5d 100644 --- a/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/common_uniform_elim_pass.h @@ -14,37 +14,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_COMMON_UNIFORM_ELIM_PASS_H_ -#define LIBSPIRV_OPT_COMMON_UNIFORM_ELIM_PASS_H_ +#ifndef SOURCE_OPT_COMMON_UNIFORM_ELIM_PASS_H_ +#define SOURCE_OPT_COMMON_UNIFORM_ELIM_PASS_H_ #include +#include #include +#include #include +#include #include #include #include +#include -#include "basic_block.h" -#include "decoration_manager.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/basic_block.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { // See optimizer.hpp for documentation. class CommonUniformElimPass : public Pass { - using cbb_ptr = const ir::BasicBlock*; + using cbb_ptr = const BasicBlock*; public: using GetBlocksFunction = - std::function*(const ir::BasicBlock*)>; + std::function*(const BasicBlock*)>; CommonUniformElimPass(); + const char* name() const override { return "eliminate-common-uniform"; } - Status Process(ir::IRContext*) override; + Status Process() override; private: // Returns true if |opcode| is a non-ptr access chain op @@ -52,7 +57,7 @@ class CommonUniformElimPass : public Pass { // Returns true if |typeInst| is a sampler or image type or a struct // containing one, recursively. - bool IsSamplerOrImageType(const ir::Instruction* typeInst) const; + bool IsSamplerOrImageType(const Instruction* typeInst) const; // Returns true if |varId| is a variable containing a sampler or image. bool IsSamplerOrImageVar(uint32_t varId) const; @@ -60,7 +65,7 @@ class CommonUniformElimPass : public Pass { // Given a load or store pointed at by |ip|, return the top-most // non-CopyObj in its pointer operand. Also return the base pointer // in |objId|. - ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* objId); + Instruction* GetPtr(Instruction* ip, uint32_t* objId); // Return true if variable is uniform bool IsUniformVar(uint32_t varId); @@ -72,13 +77,12 @@ class CommonUniformElimPass : public Pass { // Given an OpAccessChain instruction, return true // if the accessed variable belongs to a volatile // decorated object or member of a struct type - bool IsAccessChainToVolatileStructType( - const ir::Instruction& AccessChainInst); + bool IsAccessChainToVolatileStructType(const Instruction& AccessChainInst); // Given an OpLoad instruction, return true if // OpLoad has a Volatile Memory Access flag or if // the resulting type is a volatile decorated struct - bool IsVolatileLoad(const ir::Instruction& loadInst); + bool IsVolatileLoad(const Instruction& loadInst); // Return true if any uses of |id| are decorate ops. bool HasUnsupportedDecorates(uint32_t id) const; @@ -87,25 +91,24 @@ class CommonUniformElimPass : public Pass { bool HasOnlyNamesAndDecorates(uint32_t id) const; // Delete inst if it has no uses. Assumes inst has a resultId. - void DeleteIfUseless(ir::Instruction* inst); + void DeleteIfUseless(Instruction* inst); // Replace all instances of load's id with replId and delete load // and its access chain, if any - ir::Instruction* ReplaceAndDeleteLoad(ir::Instruction* loadInst, - uint32_t replId, - ir::Instruction* ptrInst); + Instruction* ReplaceAndDeleteLoad(Instruction* loadInst, uint32_t replId, + Instruction* ptrInst); // For the (constant index) access chain ptrInst, create an // equivalent load and extract - void GenACLoadRepl(const ir::Instruction* ptrInst, - std::vector>* newInsts, + void GenACLoadRepl(const Instruction* ptrInst, + std::vector>* newInsts, uint32_t* resultId); // Return true if all indices are constant - bool IsConstantIndexAccessChain(ir::Instruction* acp); + bool IsConstantIndexAccessChain(Instruction* acp); // Convert all uniform access chain loads into load/extract. - bool UniformAccessChainConvert(ir::Function* func); + bool UniformAccessChainConvert(Function* func); // Compute structured successors for function |func|. // A block's structured successors are the blocks it branches to @@ -117,7 +120,7 @@ class CommonUniformElimPass : public Pass { // // TODO(dnovillo): This pass computes structured successors slightly different // than the implementation in class Pass. Can this be re-factored? - void ComputeStructuredSuccessors(ir::Function* func); + void ComputeStructuredSuccessors(Function* func); // Compute structured block order for |func| into |structuredOrder|. This // order has the property that dominators come before all blocks they @@ -126,24 +129,23 @@ class CommonUniformElimPass : public Pass { // // TODO(dnovillo): This pass computes structured order slightly different // than the implementation in class Pass. Can this be re-factored? - void ComputeStructuredOrder(ir::Function* func, - std::list* order); + void ComputeStructuredOrder(Function* func, std::list* order); // Eliminate loads of uniform variables which have previously been loaded. // If first load is in control flow, move it to first block of function. // Most effective if preceded by UniformAccessChainRemoval(). - bool CommonUniformLoadElimination(ir::Function* func); + bool CommonUniformLoadElimination(Function* func); // Eliminate loads of uniform sampler and image variables which have // previously // been loaded in the same block for types whose loads cannot cross blocks. - bool CommonUniformLoadElimBlock(ir::Function* func); + bool CommonUniformLoadElimBlock(Function* func); // Eliminate duplicated extracts of same id. Extract may be moved to same // block as the id definition. This is primarily intended for extracts // from uniform loads. Most effective if preceded by // CommonUniformLoadElimination(). - bool CommonExtractElimination(ir::Function* func); + bool CommonExtractElimination(Function* func); // For function |func|, first change all uniform constant index // access chain loads into equivalent composite extracts. Then consolidate @@ -157,7 +159,7 @@ class CommonUniformElimPass : public Pass { // is not enabled. It also currently does not support any extensions. // // This function currently only optimizes loads with a single index. - bool EliminateCommonUniform(ir::Function* func); + bool EliminateCommonUniform(Function* func); // Initialize extensions whitelist void InitExtensions(); @@ -170,7 +172,21 @@ class CommonUniformElimPass : public Pass { return (op == SpvOpDecorate || op == SpvOpDecorateId); } - void Initialize(ir::IRContext* c); + // Return true if |inst| is an instruction that loads uniform variable and + // can be replaced with other uniform load instruction. + bool IsUniformLoadToBeRemoved(Instruction* inst) { + if (inst->opcode() == SpvOpLoad) { + uint32_t varId; + Instruction* ptrInst = GetPtr(inst, &varId); + if (ptrInst->opcode() == SpvOpVariable && IsUniformVar(varId) && + !IsSamplerOrImageVar(varId) && + !HasUnsupportedDecorates(inst->result_id()) && !IsVolatileLoad(*inst)) + return true; + } + return false; + } + + void Initialize(); Pass::Status ProcessImpl(); // Map from uniform variable id to its common load id @@ -179,7 +195,7 @@ class CommonUniformElimPass : public Pass { // Map of extract composite ids to map of indices to insts // TODO(greg-lunarg): Consider std::vector. std::unordered_map>> + std::unordered_map>> comp2idx2inst_; // Extensions supported by this pass. @@ -187,11 +203,11 @@ class CommonUniformElimPass : public Pass { // Map from block to its structured successor blocks. See // ComputeStructuredSuccessors() for definition. - std::unordered_map> + std::unordered_map> block2structured_succs_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_SSAMEM_PASS_H_ +#endif // SOURCE_OPT_COMMON_UNIFORM_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/compact_ids_pass.cpp b/3rdparty/spirv-tools/source/opt/compact_ids_pass.cpp index 98a207da7..68b940f1d 100644 --- a/3rdparty/spirv-tools/source/opt/compact_ids_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/compact_ids_pass.cpp @@ -12,25 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "compact_ids_pass.h" -#include "ir_context.h" +#include "source/opt/compact_ids_pass.h" #include #include +#include "source/opt/ir_context.h" + namespace spvtools { namespace opt { -using ir::Instruction; -using ir::Operand; - -Pass::Status CompactIdsPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status CompactIdsPass::Process() { bool modified = false; std::unordered_map result_id_mapping; - c->module()->ForEachInst( + context()->module()->ForEachInst( [&result_id_mapping, &modified](Instruction* inst) { auto operand = inst->begin(); while (operand != inst->end()) { @@ -64,7 +60,7 @@ Pass::Status CompactIdsPass::Process(ir::IRContext* c) { true); if (modified) - c->module()->SetIdBound( + context()->module()->SetIdBound( static_cast(result_id_mapping.size() + 1)); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; diff --git a/3rdparty/spirv-tools/source/opt/compact_ids_pass.h b/3rdparty/spirv-tools/source/opt/compact_ids_pass.h index cf7c3fb72..d97ae0fa4 100644 --- a/3rdparty/spirv-tools/source/opt/compact_ids_pass.h +++ b/3rdparty/spirv-tools/source/opt/compact_ids_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_COMPACT_IDS_PASS_H_ -#define LIBSPIRV_OPT_COMPACT_IDS_PASS_H_ +#ifndef SOURCE_OPT_COMPACT_IDS_PASS_H_ +#define SOURCE_OPT_COMPACT_IDS_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,10 +26,17 @@ namespace opt { class CompactIdsPass : public Pass { public: const char* name() const override { return "compact-ids"; } - Status Process(ir::IRContext*) override; + Status Process() override; + + // Return the mask of preserved Analyses. + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisLoopAnalysis; + } }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_COMPACT_IDS_PASS_H_ +#endif // SOURCE_OPT_COMPACT_IDS_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/composite.cpp b/3rdparty/spirv-tools/source/opt/composite.cpp index 1fbb71f69..2b4dca257 100644 --- a/3rdparty/spirv-tools/source/opt/composite.cpp +++ b/3rdparty/spirv-tools/source/opt/composite.cpp @@ -14,19 +14,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "composite.h" - -#include "ir_context.h" -#include "iterator.h" -#include "spirv/1.2/GLSL.std.450.h" +#include "source/opt/composite.h" #include +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "spirv/1.2/GLSL.std.450.h" + namespace spvtools { namespace opt { bool ExtInsMatch(const std::vector& extIndices, - const ir::Instruction* insInst, const uint32_t extOffset) { + const Instruction* insInst, const uint32_t extOffset) { uint32_t numIndices = static_cast(extIndices.size()) - extOffset; if (numIndices != insInst->NumInOperands() - 2) return false; for (uint32_t i = 0; i < numIndices; ++i) @@ -36,7 +36,7 @@ bool ExtInsMatch(const std::vector& extIndices, } bool ExtInsConflict(const std::vector& extIndices, - const ir::Instruction* insInst, const uint32_t extOffset) { + const Instruction* insInst, const uint32_t extOffset) { if (extIndices.size() - extOffset == insInst->NumInOperands() - 2) return false; uint32_t extNumIndices = static_cast(extIndices.size()) - extOffset; diff --git a/3rdparty/spirv-tools/source/opt/composite.h b/3rdparty/spirv-tools/source/opt/composite.h index 2153c626c..3cc036e4d 100644 --- a/3rdparty/spirv-tools/source/opt/composite.h +++ b/3rdparty/spirv-tools/source/opt/composite.h @@ -14,19 +14,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_COMPOSITE_PASS_H_ -#define LIBSPIRV_UTIL_COMPOSITE_PASS_H_ +#ifndef SOURCE_OPT_COMPOSITE_H_ +#define SOURCE_OPT_COMPOSITE_H_ #include #include #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -34,7 +35,7 @@ namespace opt { // Return true if the extract indices in |extIndices| starting at |extOffset| // match indices of insert |insInst|. bool ExtInsMatch(const std::vector& extIndices, - const ir::Instruction* insInst, const uint32_t extOffset); + const Instruction* insInst, const uint32_t extOffset); // Return true if indices in |extIndices| starting at |extOffset| and // indices of insert |insInst| conflict, specifically, if the insert @@ -42,9 +43,9 @@ bool ExtInsMatch(const std::vector& extIndices, // or less bits than the extract specifies, meaning the exact value being // inserted cannot be used to replace the extract. bool ExtInsConflict(const std::vector& extIndices, - const ir::Instruction* insInst, const uint32_t extOffset); + const Instruction* insInst, const uint32_t extOffset); } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_UTIL_COMPOSITE_PASS_H_ +#endif // SOURCE_OPT_COMPOSITE_H_ diff --git a/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp b/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp index a2d23c172..f6013a3d7 100644 --- a/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp +++ b/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp @@ -12,17 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "const_folding_rules.h" +#include "source/opt/const_folding_rules.h" + +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { - namespace { + const uint32_t kExtractCompositeIdInIdx = 0; +// Returns true if |type| is Float or a vector of Float. +bool HasFloatingPoint(const analysis::Type* type) { + if (type->AsFloat()) { + return true; + } else if (const analysis::Vector* vec_type = type->AsVector()) { + return vec_type->element_type()->AsFloat() != nullptr; + } + + return false; +} + // Folds an OpcompositeExtract where input is a composite constant. ConstantFoldingRule FoldExtractWithConstants() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; @@ -34,7 +47,6 @@ ConstantFoldingRule FoldExtractWithConstants() { uint32_t element_index = inst->GetSingleWordInOperand(i); if (c->AsNullConstant()) { // Return Null for the return type. - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); @@ -50,7 +62,7 @@ ConstantFoldingRule FoldExtractWithConstants() { } ConstantFoldingRule FoldVectorShuffleWithConstants() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { assert(inst->opcode() == SpvOpVectorShuffle); @@ -60,7 +72,6 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() { return nullptr; } - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* element_type = c1->type()->AsVector()->element_type(); @@ -84,14 +95,18 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() { } std::vector ids; + const uint32_t undef_literal_value = 0xffffffff; for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { uint32_t index = inst->GetSingleWordInOperand(i); - if (index < c1_components.size()) { - ir::Instruction* member_inst = + if (index == undef_literal_value) { + // Don't fold shuffle with undef literal value. + return nullptr; + } else if (index < c1_components.size()) { + Instruction* member_inst = const_mgr->GetDefiningInstruction(c1_components[index]); ids.push_back(member_inst->result_id()); } else { - ir::Instruction* member_inst = const_mgr->GetDefiningInstruction( + Instruction* member_inst = const_mgr->GetDefiningInstruction( c2_components[index - c1_components.size()]); ids.push_back(member_inst->result_id()); } @@ -100,25 +115,111 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() { analysis::TypeManager* type_mgr = context->get_type_mgr(); return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); }; -} // namespace +} + +ConstantFoldingRule FoldVectorTimesScalar() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + assert(inst->opcode() == SpvOpVectorTimesScalar); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { + return nullptr; + } + } + + const analysis::Constant* c1 = constants[0]; + const analysis::Constant* c2 = constants[1]; + + if (c1 && c1->IsZero()) { + return c1; + } + + if (c2 && c2->IsZero()) { + // Get or create the NullConstant for this type. + std::vector ids; + return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); + } + + if (c1 == nullptr || c2 == nullptr) { + return nullptr; + } + + // Check result type. + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + assert(vector_type != nullptr); + const analysis::Type* element_type = vector_type->element_type(); + assert(element_type != nullptr); + const analysis::Float* float_type = element_type->AsFloat(); + assert(float_type != nullptr); + + // Check types of c1 and c2. + assert(c1->type()->AsVector() == vector_type); + assert(c1->type()->AsVector()->element_type() == element_type && + c2->type() == element_type); + + // Get a float vector that is the result of vector-times-scalar. + std::vector c1_components = + c1->GetVectorComponents(const_mgr); + std::vector ids; + if (float_type->width() == 32) { + float scalar = c2->GetFloat(); + for (uint32_t i = 0; i < c1_components.size(); ++i) { + utils::FloatProxy result(c1_components[i]->GetFloat() * scalar); + std::vector words = result.GetWords(); + const analysis::Constant* new_elem = + const_mgr->GetConstant(float_type, words); + ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else if (float_type->width() == 64) { + double scalar = c2->GetDouble(); + for (uint32_t i = 0; i < c1_components.size(); ++i) { + utils::FloatProxy result(c1_components[i]->GetDouble() * + scalar); + std::vector words = result.GetWords(); + const analysis::Constant* new_elem = + const_mgr->GetConstant(float_type, words); + ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } + return nullptr; + }; +} ConstantFoldingRule FoldCompositeWithConstants() { // Folds an OpCompositeConstruct where all of the inputs are constants to a // constant. A new constant is created if necessary. - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); + Instruction* type_inst = + context->get_def_use_mgr()->GetDef(inst->type_id()); std::vector ids; - for (const analysis::Constant* element_const : constants) { + for (uint32_t i = 0; i < constants.size(); ++i) { + const analysis::Constant* element_const = constants[i]; if (element_const == nullptr) { return nullptr; } - uint32_t element_id = const_mgr->FindDeclaredConstant(element_const); + + uint32_t component_type_id = 0; + if (type_inst->opcode() == SpvOpTypeStruct) { + component_type_id = type_inst->GetSingleWordInOperand(i); + } else if (type_inst->opcode() == SpvOpTypeArray) { + component_type_id = type_inst->GetSingleWordInOperand(0); + } + + uint32_t element_id = + const_mgr->FindDeclaredConstant(element_const, component_type_id); if (element_id == 0) { return nullptr; } @@ -142,29 +243,6 @@ using BinaryScalarFoldingRule = std::function; -// Returns an std::vector containing the elements of |constant|. The type of -// |constant| must be |Vector|. -std::vector GetVectorComponents( - const analysis::Constant* constant, analysis::ConstantManager* const_mgr) { - std::vector components; - const analysis::VectorConstant* a = constant->AsVectorConstant(); - const analysis::Vector* vector_type = constant->type()->AsVector(); - assert(vector_type != nullptr); - if (a != nullptr) { - for (uint32_t i = 0; i < vector_type->element_count(); ++i) { - components.push_back(a->GetComponents()[i]); - } - } else { - const analysis::Type* element_type = vector_type->element_type(); - const analysis::Constant* element_null_const = - const_mgr->GetConstant(element_type, {}); - for (uint32_t i = 0; i < vector_type->element_count(); ++i) { - components.push_back(element_null_const); - } - } - return components; -} - // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops // using |scalar_rule| and unary float point vectors ops by applying // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| @@ -172,10 +250,9 @@ std::vector GetVectorComponents( // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| // whose element type is |Float| or |Integer|. ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { - return [scalar_rule](ir::Instruction* inst, + return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -193,7 +270,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { std::vector a_components; std::vector results_components; - a_components = GetVectorComponents(constants[0], const_mgr); + a_components = constants[0]->GetVectorComponents(const_mgr); // Fold each component of the vector. for (uint32_t i = 0; i < a_components.size(); ++i) { @@ -222,10 +299,9 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { // that |constants| contains 2 entries. If they are not |nullptr|, then their // type is either |Float| or a |Vector| whose element type is |Float|. ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { - return [scalar_rule](ir::Instruction* inst, + return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -244,8 +320,8 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { std::vector b_components; std::vector results_components; - a_components = GetVectorComponents(constants[0], const_mgr); - b_components = GetVectorComponents(constants[1], const_mgr); + a_components = constants[0]->GetVectorComponents(const_mgr); + b_components = constants[1]->GetVectorComponents(const_mgr); // Fold each component of the vector. for (uint32_t i = 0; i < a_components.size(); ++i) { @@ -300,7 +376,7 @@ UnaryScalarFoldingRule FoldFToIOp() { }; } -// This macro defines a |UnaryScalarFoldingRule| that performs integer to +// This function defines a |UnaryScalarFoldingRule| that performs integer to // float conversion. // TODO(greg-lunarg): Support for 64-bit integer types. UnaryScalarFoldingRule FoldIToFOp() { @@ -317,14 +393,14 @@ UnaryScalarFoldingRule FoldIToFOp() { float result_val = integer_type->IsSigned() ? static_cast(static_cast(ua)) : static_cast(ua); - spvutils::FloatProxy result(result_val); + utils::FloatProxy result(result_val); std::vector words = {result.data()}; return const_mgr->GetConstant(result_type, words); } else if (float_type->width() == 64) { double result_val = integer_type->IsSigned() ? static_cast(static_cast(ua)) : static_cast(ua); - spvutils::FloatProxy result(result_val); + utils::FloatProxy result(result_val); std::vector words = result.GetWords(); return const_mgr->GetConstant(result_type, words); } @@ -334,28 +410,29 @@ UnaryScalarFoldingRule FoldIToFOp() { // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The // operator |op| must work for both float and double, and use syntax "f1 op f2". -#define FOLD_FPARITH_OP(op) \ - [](const analysis::Type* result_type, const analysis::Constant* a, \ - const analysis::Constant* b, \ - analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \ - assert(result_type != nullptr && a != nullptr && b != nullptr); \ - assert(result_type == a->type() && result_type == b->type()); \ - const analysis::Float* float_type = result_type->AsFloat(); \ - assert(float_type != nullptr); \ - if (float_type->width() == 32) { \ - float fa = a->GetFloat(); \ - float fb = b->GetFloat(); \ - spvutils::FloatProxy result(fa op fb); \ - std::vector words = result.GetWords(); \ - return const_mgr->GetConstant(result_type, words); \ - } else if (float_type->width() == 64) { \ - double fa = a->GetDouble(); \ - double fb = b->GetDouble(); \ - spvutils::FloatProxy result(fa op fb); \ - std::vector words = result.GetWords(); \ - return const_mgr->GetConstant(result_type, words); \ - } \ - return nullptr; \ +#define FOLD_FPARITH_OP(op) \ + [](const analysis::Type* result_type, const analysis::Constant* a, \ + const analysis::Constant* b, \ + analysis::ConstantManager* const_mgr_in_macro) \ + -> const analysis::Constant* { \ + assert(result_type != nullptr && a != nullptr && b != nullptr); \ + assert(result_type == a->type() && result_type == b->type()); \ + const analysis::Float* float_type_in_macro = result_type->AsFloat(); \ + assert(float_type_in_macro != nullptr); \ + if (float_type_in_macro->width() == 32) { \ + float fa = a->GetFloat(); \ + float fb = b->GetFloat(); \ + utils::FloatProxy result_in_macro(fa op fb); \ + std::vector words_in_macro = result_in_macro.GetWords(); \ + return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ + } else if (float_type_in_macro->width() == 64) { \ + double fa = a->GetDouble(); \ + double fb = b->GetDouble(); \ + utils::FloatProxy result_in_macro(fa op fb); \ + std::vector words_in_macro = result_in_macro.GetWords(); \ + return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ + } \ + return nullptr; \ } // Define the folding rule for conversion between floating point and integer @@ -447,9 +524,262 @@ ConstantFoldingRule FoldFOrdGreaterThanEqual() { ConstantFoldingRule FoldFUnordGreaterThanEqual() { return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); } + +// Folds an OpDot where all of the inputs are constants to a +// constant. A new constant is created if necessary. +ConstantFoldingRule FoldOpDotWithConstants() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); + assert(new_type->AsFloat() && "OpDot should have a float return type."); + const analysis::Float* float_type = new_type->AsFloat(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + // If one of the operands is 0, then the result is 0. + bool has_zero_operand = false; + + for (int i = 0; i < 2; ++i) { + if (constants[i]) { + if (constants[i]->AsNullConstant() || + constants[i]->AsVectorConstant()->IsZero()) { + has_zero_operand = true; + break; + } + } + } + + if (has_zero_operand) { + if (float_type->width() == 32) { + utils::FloatProxy result(0.0f); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(float_type, words); + } + if (float_type->width() == 64) { + utils::FloatProxy result(0.0); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(float_type, words); + } + return nullptr; + } + + if (constants[0] == nullptr || constants[1] == nullptr) { + return nullptr; + } + + std::vector a_components; + std::vector b_components; + + a_components = constants[0]->GetVectorComponents(const_mgr); + b_components = constants[1]->GetVectorComponents(const_mgr); + + utils::FloatProxy result(0.0); + std::vector words = result.GetWords(); + const analysis::Constant* result_const = + const_mgr->GetConstant(float_type, words); + for (uint32_t i = 0; i < a_components.size(); ++i) { + if (a_components[i] == nullptr || b_components[i] == nullptr) { + return nullptr; + } + + const analysis::Constant* component = FOLD_FPARITH_OP(*)( + new_type, a_components[i], b_components[i], const_mgr); + result_const = + FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); + } + return result_const; + }; +} + +// This function defines a |UnaryScalarFoldingRule| that subtracts the constant +// from zero. +UnaryScalarFoldingRule FoldFNegateOp() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + assert(result_type == a->type()); + const analysis::Float* float_type = result_type->AsFloat(); + assert(float_type != nullptr); + if (float_type->width() == 32) { + float fa = a->GetFloat(); + utils::FloatProxy result(-fa); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(result_type, words); + } else if (float_type->width() == 64) { + double da = a->GetDouble(); + utils::FloatProxy result(-da); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(result_type, words); + } + return nullptr; + }; +} + +ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); } + +ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) { + return [cmp_opcode](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + uint32_t non_const_idx = (constants[0] ? 1 : 0); + uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx); + Instruction* operand_inst = def_use_mgr->GetDef(operand_id); + + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* operand_type = + type_mgr->GetType(operand_inst->type_id()); + + if (!operand_type->AsFloat()) { + return nullptr; + } + + if (operand_type->AsFloat()->width() != 32 && + operand_type->AsFloat()->width() != 64) { + return nullptr; + } + + if (operand_inst->opcode() != SpvOpExtInst) { + return nullptr; + } + + if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) { + return nullptr; + } + + if (constants[1] == nullptr && constants[0] == nullptr) { + return nullptr; + } + + uint32_t max_id = operand_inst->GetSingleWordInOperand(4); + const analysis::Constant* max_const = + const_mgr->FindDeclaredConstant(max_id); + + uint32_t min_id = operand_inst->GetSingleWordInOperand(3); + const analysis::Constant* min_const = + const_mgr->FindDeclaredConstant(min_id); + + bool found_result = false; + bool result = false; + + switch (cmp_opcode) { + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: + if (constants[0]) { + if (min_const) { + if (constants[0]->GetValueAsDouble() < + min_const->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + if (max_const) { + if (constants[0]->GetValueAsDouble() >= + max_const->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + } + + if (constants[1]) { + if (max_const) { + if (max_const->GetValueAsDouble() < + constants[1]->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + + if (min_const) { + if (min_const->GetValueAsDouble() >= + constants[1]->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + } + break; + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + if (constants[0]) { + if (min_const) { + if (constants[0]->GetValueAsDouble() <= + min_const->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + if (max_const) { + if (constants[0]->GetValueAsDouble() > + max_const->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + } + + if (constants[1]) { + if (max_const) { + if (max_const->GetValueAsDouble() <= + constants[1]->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + + if (min_const) { + if (min_const->GetValueAsDouble() > + constants[1]->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + } + break; + default: + return nullptr; + } + + if (!found_result) { + return nullptr; + } + + const analysis::Type* bool_type = + context->get_type_mgr()->GetType(inst->type_id()); + const analysis::Constant* result_const = + const_mgr->GetConstant(bool_type, {static_cast(result)}); + assert(result_const); + return result_const; + }; +} + } // namespace -spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() { +ConstantFoldingRules::ConstantFoldingRules() { // Add all folding rules to the list for the opcodes to which they apply. // Note that the order in which rules are added to the list matters. If a rule // applies to the instruction, the rest of the rules will not be attempted. @@ -464,25 +794,56 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() { rules_[SpvOpConvertSToF].push_back(FoldIToF()); rules_[SpvOpConvertUToF].push_back(FoldIToF()); + rules_[SpvOpDot].push_back(FoldOpDotWithConstants()); rules_[SpvOpFAdd].push_back(FoldFAdd()); rules_[SpvOpFDiv].push_back(FoldFDiv()); rules_[SpvOpFMul].push_back(FoldFMul()); rules_[SpvOpFSub].push_back(FoldFSub()); rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual()); + rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual()); + rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual()); + rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual()); + rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan()); + rules_[SpvOpFOrdLessThan].push_back( + FoldFClampFeedingCompare(SpvOpFOrdLessThan)); + rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan()); + rules_[SpvOpFUnordLessThan].push_back( + FoldFClampFeedingCompare(SpvOpFUnordLessThan)); + rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); + rules_[SpvOpFOrdGreaterThan].push_back( + FoldFClampFeedingCompare(SpvOpFOrdGreaterThan)); + rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); + rules_[SpvOpFUnordGreaterThan].push_back( + FoldFClampFeedingCompare(SpvOpFUnordGreaterThan)); + rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); + rules_[SpvOpFOrdLessThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual)); + rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); + rules_[SpvOpFUnordLessThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual)); + rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); + rules_[SpvOpFOrdGreaterThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual)); + rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); + rules_[SpvOpFUnordGreaterThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual)); rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); + rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar()); + + rules_[SpvOpFNegate].push_back(FoldFNegate()); } } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/const_folding_rules.h b/3rdparty/spirv-tools/source/opt/const_folding_rules.h index 2d9ecbaa3..c1865792b 100644 --- a/3rdparty/spirv-tools/source/opt/const_folding_rules.h +++ b/3rdparty/spirv-tools/source/opt/const_folding_rules.h @@ -12,17 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_CONST_FOLDING_RULES_H_ -#define LIBSPIRV_OPT_CONST_FOLDING_RULES_H_ +#ifndef SOURCE_OPT_CONST_FOLDING_RULES_H_ +#define SOURCE_OPT_CONST_FOLDING_RULES_H_ +#include #include -#include "constants.h" -#include "def_use_manager.h" -#include "folding_rules.h" -#include "ir_builder.h" -#include "ir_context.h" -#include "latest_version_spirv_header.h" +#include "source/opt/constants.h" namespace spvtools { namespace opt { @@ -53,7 +49,7 @@ namespace opt { // fold an instruction, the later rules will not be attempted. using ConstantFoldingRule = std::function& constants)>; class ConstantFoldingRules { @@ -81,4 +77,4 @@ class ConstantFoldingRules { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_CONST_FOLDING_RULES_H_ +#endif // SOURCE_OPT_CONST_FOLDING_RULES_H_ diff --git a/3rdparty/spirv-tools/source/opt/constants.cpp b/3rdparty/spirv-tools/source/opt/constants.cpp index 1eb9efe14..ecb5f97c4 100644 --- a/3rdparty/spirv-tools/source/opt/constants.cpp +++ b/3rdparty/spirv-tools/source/opt/constants.cpp @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "constants.h" -#include "ir_context.h" +#include "source/opt/constants.h" #include #include +#include "source/opt/ir_context.h" + namespace spvtools { namespace opt { namespace analysis { @@ -44,6 +45,16 @@ double Constant::GetDouble() const { } } +double Constant::GetValueAsDouble() const { + assert(type()->AsFloat() != nullptr); + if (type()->AsFloat()->width() == 32) { + return GetFloat(); + } else { + assert(type()->AsFloat()->width() == 64); + return GetDouble(); + } +} + uint32_t Constant::GetU32() const { assert(type()->AsInteger() != nullptr); assert(type()->AsInteger()->width() == 32); @@ -92,7 +103,7 @@ int64_t Constant::GetS64() const { } } -ConstantManager::ConstantManager(ir::IRContext* ctx) : ctx_(ctx) { +ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) { // Populate the constant table with values from constant declarations in the // module. The values of each OpConstant declaration is the identity // assignment (i.e., each constant is its own value). @@ -101,15 +112,15 @@ ConstantManager::ConstantManager(ir::IRContext* ctx) : ctx_(ctx) { } } -Type* ConstantManager::GetType(const ir::Instruction* inst) const { +Type* ConstantManager::GetType(const Instruction* inst) const { return context()->get_type_mgr()->GetType(inst->type_id()); } std::vector ConstantManager::GetOperandConstants( - ir::Instruction* inst) const { + Instruction* inst) const { std::vector constants; for (uint32_t i = 0; i < inst->NumInOperands(); i++) { - const ir::Operand* operand = &inst->GetInOperand(i); + const Operand* operand = &inst->GetInOperand(i); if (operand->type != SPV_OPERAND_TYPE_ID) { constants.push_back(nullptr); } else { @@ -121,6 +132,24 @@ std::vector ConstantManager::GetOperandConstants( return constants; } +uint32_t ConstantManager::FindDeclaredConstant(const Constant* c, + uint32_t type_id) const { + c = FindConstant(c); + if (c == nullptr) { + return 0; + } + + for (auto range = const_val_to_id_.equal_range(c); + range.first != range.second; ++range.first) { + Instruction* const_def = + context()->get_def_use_mgr()->GetDef(range.first->second); + if (type_id == 0 || const_def->type_id() == type_id) { + return range.first->second; + } + } + return 0; +} + std::vector ConstantManager::GetConstantsFromIds( const std::vector& ids) const { std::vector constants; @@ -134,9 +163,8 @@ std::vector ConstantManager::GetConstantsFromIds( return constants; } -ir::Instruction* ConstantManager::BuildInstructionAndAddToModule( - const Constant* new_const, ir::Module::inst_iterator* pos, - uint32_t type_id) { +Instruction* ConstantManager::BuildInstructionAndAddToModule( + const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) { uint32_t new_id = context()->TakeNextId(); auto new_inst = CreateInstruction(new_id, new_const, type_id); if (!new_inst) { @@ -150,33 +178,37 @@ ir::Instruction* ConstantManager::BuildInstructionAndAddToModule( return new_inst_ptr; } -ir::Instruction* ConstantManager::GetDefiningInstruction( - const Constant* c, ir::Module::inst_iterator* pos) { - uint32_t decl_id = FindDeclaredConstant(c); +Instruction* ConstantManager::GetDefiningInstruction( + const Constant* c, uint32_t type_id, Module::inst_iterator* pos) { + assert(type_id == 0 || + context()->get_type_mgr()->GetType(type_id) == c->type()); + uint32_t decl_id = FindDeclaredConstant(c, type_id); if (decl_id == 0) { auto iter = context()->types_values_end(); if (pos == nullptr) pos = &iter; - return BuildInstructionAndAddToModule(c, pos); + return BuildInstructionAndAddToModule(c, pos, type_id); } else { auto def = context()->get_def_use_mgr()->GetDef(decl_id); assert(def != nullptr); + assert((type_id == 0 || def->type_id() == type_id) && + "This constant already has an instruction with a different type."); return def; } } -const Constant* ConstantManager::CreateConstant( +std::unique_ptr ConstantManager::CreateConstant( const Type* type, const std::vector& literal_words_or_ids) const { if (literal_words_or_ids.size() == 0) { // Constant declared with OpConstantNull - return new NullConstant(type); + return MakeUnique(type); } else if (auto* bt = type->AsBool()) { assert(literal_words_or_ids.size() == 1 && "Bool constant should be declared with one operand"); - return new BoolConstant(bt, literal_words_or_ids.front()); + return MakeUnique(bt, literal_words_or_ids.front()); } else if (auto* it = type->AsInteger()) { - return new IntConstant(it, literal_words_or_ids); + return MakeUnique(it, literal_words_or_ids); } else if (auto* ft = type->AsFloat()) { - return new FloatConstant(ft, literal_words_or_ids); + return MakeUnique(ft, literal_words_or_ids); } else if (auto* vt = type->AsVector()) { auto components = GetConstantsFromIds(literal_words_or_ids); if (components.empty()) return nullptr; @@ -199,25 +231,25 @@ const Constant* ConstantManager::CreateConstant( return false; })) return nullptr; - return new VectorConstant(vt, components); + return MakeUnique(vt, components); } else if (auto* mt = type->AsMatrix()) { auto components = GetConstantsFromIds(literal_words_or_ids); if (components.empty()) return nullptr; - return new MatrixConstant(mt, components); + return MakeUnique(mt, components); } else if (auto* st = type->AsStruct()) { auto components = GetConstantsFromIds(literal_words_or_ids); if (components.empty()) return nullptr; - return new StructConstant(st, components); + return MakeUnique(st, components); } else if (auto* at = type->AsArray()) { auto components = GetConstantsFromIds(literal_words_or_ids); if (components.empty()) return nullptr; - return new ArrayConstant(at, components); + return MakeUnique(at, components); } else { return nullptr; } } -const Constant* ConstantManager::GetConstantFromInst(ir::Instruction* inst) { +const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) { std::vector literal_words_or_ids; // Collect the constant defining literals or component ids. @@ -248,31 +280,30 @@ const Constant* ConstantManager::GetConstantFromInst(ir::Instruction* inst) { return GetConstant(GetType(inst), literal_words_or_ids); } -std::unique_ptr ConstantManager::CreateInstruction( +std::unique_ptr ConstantManager::CreateInstruction( uint32_t id, const Constant* c, uint32_t type_id) const { uint32_t type = (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id; if (c->AsNullConstant()) { - return MakeUnique(context(), SpvOp::SpvOpConstantNull, - type, id, - std::initializer_list{}); + return MakeUnique(context(), SpvOp::SpvOpConstantNull, type, + id, std::initializer_list{}); } else if (const BoolConstant* bc = c->AsBoolConstant()) { - return MakeUnique( + return MakeUnique( context(), bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, - type, id, std::initializer_list{}); + type, id, std::initializer_list{}); } else if (const IntConstant* ic = c->AsIntConstant()) { - return MakeUnique( + return MakeUnique( context(), SpvOp::SpvOpConstant, type, id, - std::initializer_list{ir::Operand( - spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, - ic->words())}); + std::initializer_list{ + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + ic->words())}); } else if (const FloatConstant* fc = c->AsFloatConstant()) { - return MakeUnique( + return MakeUnique( context(), SpvOp::SpvOpConstant, type, id, - std::initializer_list{ir::Operand( - spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, - fc->words())}); + std::initializer_list{ + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + fc->words())}); } else if (const CompositeConstant* cc = c->AsCompositeConstant()) { return CreateCompositeInstruction(id, cc, type_id); } else { @@ -280,11 +311,20 @@ std::unique_ptr ConstantManager::CreateInstruction( } } -std::unique_ptr ConstantManager::CreateCompositeInstruction( +std::unique_ptr ConstantManager::CreateCompositeInstruction( uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const { - std::vector operands; + std::vector operands; + Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); + uint32_t component_index = 0; for (const Constant* component_const : cc->GetComponents()) { - uint32_t id = FindDeclaredConstant(component_const); + uint32_t component_type_id = 0; + if (type_inst && type_inst->opcode() == SpvOpTypeStruct) { + component_type_id = type_inst->GetSingleWordInOperand(component_index); + } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) { + component_type_id = type_inst->GetSingleWordInOperand(0); + } + uint32_t id = FindDeclaredConstant(component_const, component_type_id); + if (id == 0) { // Cannot get the id of the component constant, while all components // should have been added to the module prior to the composite constant. @@ -293,17 +333,39 @@ std::unique_ptr ConstantManager::CreateCompositeInstruction( } operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, std::initializer_list{id}); + component_index++; } uint32_t type = (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id; - return MakeUnique(context(), SpvOp::SpvOpConstantComposite, - type, result_id, std::move(operands)); + return MakeUnique(context(), SpvOp::SpvOpConstantComposite, type, + result_id, std::move(operands)); } const Constant* ConstantManager::GetConstant( const Type* type, const std::vector& literal_words_or_ids) { auto cst = CreateConstant(type, literal_words_or_ids); - return cst ? RegisterConstant(cst) : nullptr; + return cst ? RegisterConstant(std::move(cst)) : nullptr; +} + +std::vector Constant::GetVectorComponents( + analysis::ConstantManager* const_mgr) const { + std::vector components; + const analysis::VectorConstant* a = this->AsVectorConstant(); + const analysis::Vector* vector_type = this->type()->AsVector(); + assert(vector_type != nullptr); + if (a != nullptr) { + for (uint32_t i = 0; i < vector_type->element_count(); ++i) { + components.push_back(a->GetComponents()[i]); + } + } else { + const analysis::Type* element_type = vector_type->element_type(); + const analysis::Constant* element_null_const = + const_mgr->GetConstant(element_type, {}); + for (uint32_t i = 0; i < vector_type->element_count(); ++i) { + components.push_back(element_null_const); + } + } + return components; } } // namespace analysis diff --git a/3rdparty/spirv-tools/source/opt/constants.h b/3rdparty/spirv-tools/source/opt/constants.h index d1c1fbeaa..de2dfc3d0 100644 --- a/3rdparty/spirv-tools/source/opt/constants.h +++ b/3rdparty/spirv-tools/source/opt/constants.h @@ -12,25 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_CONSTANTS_H_ -#define LIBSPIRV_OPT_CONSTANTS_H_ +#ifndef SOURCE_OPT_CONSTANTS_H_ +#define SOURCE_OPT_CONSTANTS_H_ -#include #include +#include #include #include #include #include #include -#include "make_unique.h" -#include "module.h" -#include "type_manager.h" -#include "types.h" -#include "util/hex_float.h" +#include "source/opt/module.h" +#include "source/opt/type_manager.h" +#include "source/opt/types.h" +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { + +class IRContext; + namespace analysis { // Class hierarchy to represent the normal constants defined through @@ -48,6 +51,7 @@ class VectorConstant; class MatrixConstant; class ArrayConstant; class NullConstant; +class ConstantManager; // Abstract class for a SPIR-V constant. It has a bunch of As methods, // which is used as a way to probe the actual @@ -92,6 +96,10 @@ class Constant { // Float type. double GetDouble() const; + // Returns the double representation of the constant. Must be a 32-bit or + // 64-bit Float type. + double GetValueAsDouble() const; + // Returns uint32_t representation of the constant. Must be a 32 bit // Integer type. uint32_t GetU32() const; @@ -108,8 +116,16 @@ class Constant { // Integer type. int64_t GetS64() const; + // Returns true if the constant is a zero or a composite containing 0s. + virtual bool IsZero() const { return false; } + const Type* type() const { return type_; } + // Returns an std::vector containing the elements of |constant|. The type of + // |constant| must be |Vector|. + std::vector GetVectorComponents( + ConstantManager* const_mgr) const; + protected: Constant(const Type* ty) : type_(ty) {} @@ -128,7 +144,7 @@ class ScalarConstant : public Constant { virtual const std::vector& words() const { return words_; } // Returns true if the value is zero. - bool IsZero() const { + bool IsZero() const override { bool is_zero = true; for (uint32_t v : words()) { if (v != 0) { @@ -221,7 +237,7 @@ class FloatConstant : public ScalarConstant { float GetFloatValue() const { assert(type()->AsFloat()->width() == 32 && "Not a 32-bit floating point value."); - spvutils::FloatProxy a(words()[0]); + utils::FloatProxy a(words()[0]); return a.getAsFloat(); } @@ -233,7 +249,7 @@ class FloatConstant : public ScalarConstant { uint64_t combined_words = words()[1]; combined_words = combined_words << 32; combined_words |= words()[0]; - spvutils::FloatProxy a(combined_words); + utils::FloatProxy a(combined_words); return a.getAsFloat(); } }; @@ -274,6 +290,15 @@ class CompositeConstant : public Constant { return components_; } + bool IsZero() const override { + for (const Constant* c : GetComponents()) { + if (!c->IsZero()) { + return false; + } + } + return true; + } + protected: CompositeConstant(const Type* ty) : Constant(ty), components_() {} CompositeConstant(const Type* ty, @@ -407,10 +432,9 @@ class NullConstant : public Constant { std::unique_ptr Copy() const override { return std::unique_ptr(CopyNullConstant().release()); } + bool IsZero() const override { return true; }; }; -class IRContext; - // Hash function for Constant instances. Use the structure of the constant as // the key. struct ConstantHash { @@ -457,10 +481,11 @@ struct ConstantEqual { const auto& composite2 = c2->AsCompositeConstant(); return composite2 && composite1->GetComponents() == composite2->GetComponents(); - } else if (c1->AsNullConstant()) + } else if (c1->AsNullConstant()) { return c2->AsNullConstant() != nullptr; - else + } else { assert(false && "Tried to compare two invalid Constant instances."); + } return false; } }; @@ -468,9 +493,9 @@ struct ConstantEqual { // This class represents a pool of constants. class ConstantManager { public: - ConstantManager(ir::IRContext* ctx); + ConstantManager(IRContext* ctx); - ir::IRContext* context() const { return ctx_; } + IRContext* context() const { return ctx_; } // Gets or creates a unique Constant instance of type |type| and a vector of // constant defining words |words|. If a Constant instance existed already in @@ -480,10 +505,16 @@ class ConstantManager { const Constant* GetConstant( const Type* type, const std::vector& literal_words_or_ids); + template + const Constant* GetConstant(const Type* type, const C& literal_words_or_ids) { + return GetConstant(type, std::vector(literal_words_or_ids.begin(), + literal_words_or_ids.end())); + } + // Gets or creates a Constant instance to hold the constant value of the given // instruction. It returns a pointer to a Constant instance or nullptr if it // could not create the constant. - const Constant* GetConstantFromInst(ir::Instruction* inst); + const Constant* GetConstantFromInst(Instruction* inst); // Gets or creates a constant defining instruction for the given Constant |c|. // If |c| had already been defined, it returns a pointer to the existing @@ -491,8 +522,16 @@ class ConstantManager { // optional |pos| is given, it will insert any newly created instructions at // the given instruction iterator position. Otherwise, it inserts the new // instruction at the end of the current module's types section. - ir::Instruction* GetDefiningInstruction( - const Constant* c, ir::Module::inst_iterator* pos = nullptr); + // + // |type_id| is an optional argument for disambiguating equivalent types. If + // |type_id| is specified, it is used as the type of the constant when a new + // instruction is created. Otherwise the type of the constant is derived by + // getting an id from the type manager for |c|. + // + // When |type_id| is not zero, the type of |c| must be the type returned by + // type manager when given |type_id|. + Instruction* GetDefiningInstruction(const Constant* c, uint32_t type_id = 0, + Module::inst_iterator* pos = nullptr); // Creates a constant defining instruction for the given Constant instance // and inserts the instruction at the position specified by the given @@ -506,12 +545,13 @@ class ConstantManager { // |type_id| is specified, it is used as the type of the constant. Otherwise // the type of the constant is derived by getting an id from the type manager // for |c|. - ir::Instruction* BuildInstructionAndAddToModule( - const Constant* c, ir::Module::inst_iterator* pos, uint32_t type_id = 0); + Instruction* BuildInstructionAndAddToModule(const Constant* c, + Module::inst_iterator* pos, + uint32_t type_id = 0); // A helper function to get the result type of the given instruction. Returns // nullptr if the instruction does not have a type id (type id is 0). - Type* GetType(const ir::Instruction* inst) const; + Type* GetType(const Instruction* inst) const; // A helper function to get the collected normal constant with the given id. // Returns the pointer to the Constant instance in case it is found. @@ -523,13 +563,13 @@ class ConstantManager { // A helper function to get the id of a collected constant with the pointer // to the Constant instance. Returns 0 in case the constant is not found. - uint32_t FindDeclaredConstant(const Constant* c) const { - auto iter = const_val_to_id_.find(c); - return (iter != const_val_to_id_.end()) ? iter->second : 0; - } + uint32_t FindDeclaredConstant(const Constant* c, uint32_t type_id) const; // Returns the canonical constant that has the same structure and value as the // given Constant |cst|. If none is found, it returns nullptr. + // + // TODO: Should be able to give a type id to disambiguate types with the same + // structure. const Constant* FindConstant(const Constant* c) const { auto it = const_pool_.find(c); return (it != const_pool_.end()) ? *it : nullptr; @@ -538,8 +578,11 @@ class ConstantManager { // Registers a new constant |cst| in the constant pool. If the constant // existed already, it returns a pointer to the previously existing Constant // in the pool. Otherwise, it returns |cst|. - const Constant* RegisterConstant(const Constant* cst) { - auto ret = const_pool_.insert(cst); + const Constant* RegisterConstant(std::unique_ptr cst) { + auto ret = const_pool_.insert(cst.get()); + if (ret.second) { + owned_constants_.emplace_back(std::move(cst)); + } return *ret.first; } @@ -551,12 +594,12 @@ class ConstantManager { // Returns a vector of constants representing each in operand. If an operand // is not constant its entry is nullptr. - std::vector GetOperandConstants(ir::Instruction* inst) const; + std::vector GetOperandConstants(Instruction* inst) const; // Records a mapping between |inst| and the constant value generated by it. // It returns true if a new Constant was successfully mapped, false if |inst| // generates no constant values. - bool MapInst(ir::Instruction* inst) { + bool MapInst(Instruction* inst) { if (auto cst = GetConstantFromInst(inst)) { MapConstantToInst(cst, inst); return true; @@ -574,9 +617,10 @@ class ConstantManager { // Records a new mapping between |inst| and |const_value|. This updates the // two mappings |id_to_const_val_| and |const_val_to_id_|. - void MapConstantToInst(const Constant* const_value, ir::Instruction* inst) { - const_val_to_id_[const_value] = inst->result_id(); - id_to_const_val_[inst->result_id()] = const_value; + void MapConstantToInst(const Constant* const_value, Instruction* inst) { + if (id_to_const_val_.insert({inst->result_id(), const_value}).second) { + const_val_to_id_.insert({const_value, inst->result_id()}); + } } private: @@ -592,7 +636,7 @@ class ConstantManager { // type, either Bool, Integer or Float. If any of the rules above failed, the // creation will fail and nullptr will be returned. If the vector is empty, // a NullConstant instance will be created with the given type. - const Constant* CreateConstant( + std::unique_ptr CreateConstant( const Type* type, const std::vector& literal_words_or_ids) const; @@ -605,8 +649,9 @@ class ConstantManager { // |type_id| is specified, it is used as the type of the constant. Otherwise // the type of the constant is derived by getting an id from the type manager // for |c|. - std::unique_ptr CreateInstruction( - uint32_t result_id, const Constant* c, uint32_t type_id = 0) const; + std::unique_ptr CreateInstruction(uint32_t result_id, + const Constant* c, + uint32_t type_id = 0) const; // Creates an OpConstantComposite instruction with the given result id and // the CompositeConst instance which represents a composite constant. Returns @@ -617,12 +662,12 @@ class ConstantManager { // |type_id| is specified, it is used as the type of the constant. Otherwise // the type of the constant is derived by getting an id from the type manager // for |c|. - std::unique_ptr CreateCompositeInstruction( + std::unique_ptr CreateCompositeInstruction( uint32_t result_id, const CompositeConstant* cc, uint32_t type_id = 0) const; // IR context that owns this constant manager. - ir::IRContext* ctx_; + IRContext* ctx_; // A mapping from the result ids of Normal Constants to their // Constant instances. All Normal Constants in the module, either @@ -634,14 +679,18 @@ class ConstantManager { // result id in the module. This is a mirror map of |id_to_const_val_|. All // Normal Constants that defining instructions in the module should have // their Constant and their result id registered here. - std::unordered_map const_val_to_id_; + std::multimap const_val_to_id_; // The constant pool. All created constants are registered here. std::unordered_set const_pool_; + + // The constant that are owned by the constant manager. Every constant in + // |const_pool_| should be in |owned_constants_| as well. + std::vector> owned_constants_; }; } // namespace analysis } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_CONSTANTS_H_ +#endif // SOURCE_OPT_CONSTANTS_H_ diff --git a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp index eb4694f22..028b237d7 100644 --- a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp +++ b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp @@ -12,25 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "copy_prop_arrays.h" -#include "ir_builder.h" +#include "source/opt/copy_prop_arrays.h" +#include + +#include "source/opt/ir_builder.h" + +namespace spvtools { +namespace opt { namespace { + const uint32_t kLoadPointerInOperand = 0; const uint32_t kStorePointerInOperand = 0; const uint32_t kStoreObjectInOperand = 1; const uint32_t kCompositeExtractObjectInOperand = 0; + } // namespace -namespace spvtools { -namespace opt { - -Pass::Status CopyPropagateArrays::Process(ir::IRContext* ctx) { - InitializeProcessing(ctx); - +Pass::Status CopyPropagateArrays::Process() { bool modified = false; - for (ir::Function& function : *get_module()) { - ir::BasicBlock* entry_bb = &*function.begin(); + for (Function& function : *get_module()) { + BasicBlock* entry_bb = &*function.begin(); for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable; ++var_inst) { @@ -39,7 +41,7 @@ Pass::Status CopyPropagateArrays::Process(ir::IRContext* ctx) { } // Find the only store to the entire memory location, if it exists. - ir::Instruction* store_inst = FindStoreInstruction(&*var_inst); + Instruction* store_inst = FindStoreInstruction(&*var_inst); if (!store_inst) { continue; @@ -60,8 +62,8 @@ Pass::Status CopyPropagateArrays::Process(ir::IRContext* ctx) { } std::unique_ptr -CopyPropagateArrays::FindSourceObjectIfPossible(ir::Instruction* var_inst, - ir::Instruction* store_inst) { +CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst, + Instruction* store_inst) { assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable."); // Check that the variable is a composite object where |store_inst| @@ -95,11 +97,11 @@ CopyPropagateArrays::FindSourceObjectIfPossible(ir::Instruction* var_inst, return source; } -ir::Instruction* CopyPropagateArrays::FindStoreInstruction( - const ir::Instruction* var_inst) const { - ir::Instruction* store_inst = nullptr; +Instruction* CopyPropagateArrays::FindStoreInstruction( + const Instruction* var_inst) const { + Instruction* store_inst = nullptr; get_def_use_mgr()->WhileEachUser( - var_inst, [&store_inst, var_inst](ir::Instruction* use) { + var_inst, [&store_inst, var_inst](Instruction* use) { if (use->opcode() == SpvOpStore && use->GetSingleWordInOperand(kStorePointerInOperand) == var_inst->result_id()) { @@ -115,24 +117,23 @@ ir::Instruction* CopyPropagateArrays::FindStoreInstruction( return store_inst; } -void CopyPropagateArrays::PropagateObject(ir::Instruction* var_inst, +void CopyPropagateArrays::PropagateObject(Instruction* var_inst, MemoryObject* source, - ir::Instruction* insertion_point) { + Instruction* insertion_point) { assert(var_inst->opcode() == SpvOpVariable && "This function propagates variables."); - ir::Instruction* new_access_chain = - BuildNewAccessChain(insertion_point, source); + Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source); context()->KillNamesAndDecorates(var_inst); UpdateUses(var_inst, new_access_chain); } -ir::Instruction* CopyPropagateArrays::BuildNewAccessChain( - ir::Instruction* insertion_point, +Instruction* CopyPropagateArrays::BuildNewAccessChain( + Instruction* insertion_point, CopyPropagateArrays::MemoryObject* source) const { - InstructionBuilder builder(context(), insertion_point, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + InstructionBuilder builder( + context(), insertion_point, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); if (source->AccessChain().size() == 0) { return source->GetVariable(); @@ -143,34 +144,33 @@ ir::Instruction* CopyPropagateArrays::BuildNewAccessChain( source->AccessChain()); } -bool CopyPropagateArrays::HasNoStores(ir::Instruction* ptr_inst) { - return get_def_use_mgr()->WhileEachUser( - ptr_inst, [this](ir::Instruction* use) { - if (use->opcode() == SpvOpLoad) { - return true; - } else if (use->opcode() == SpvOpAccessChain) { - return HasNoStores(use); - } else if (use->IsDecoration() || use->opcode() == SpvOpName) { - return true; - } else if (use->opcode() == SpvOpStore) { - return false; - } else if (use->opcode() == SpvOpImageTexelPointer) { - return true; - } - // Some other instruction. Be conservative. - return false; - }); +bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) { + return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) { + if (use->opcode() == SpvOpLoad) { + return true; + } else if (use->opcode() == SpvOpAccessChain) { + return HasNoStores(use); + } else if (use->IsDecoration() || use->opcode() == SpvOpName) { + return true; + } else if (use->opcode() == SpvOpStore) { + return false; + } else if (use->opcode() == SpvOpImageTexelPointer) { + return true; + } + // Some other instruction. Be conservative. + return false; + }); } -bool CopyPropagateArrays::HasValidReferencesOnly(ir::Instruction* ptr_inst, - ir::Instruction* store_inst) { - ir::BasicBlock* store_block = context()->get_instr_block(store_inst); - opt::DominatorAnalysis* dominator_analysis = - context()->GetDominatorAnalysis(store_block->GetParent(), *cfg()); +bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst, + Instruction* store_inst) { + BasicBlock* store_block = context()->get_instr_block(store_inst); + DominatorAnalysis* dominator_analysis = + context()->GetDominatorAnalysis(store_block->GetParent()); return get_def_use_mgr()->WhileEachUser( ptr_inst, - [this, store_inst, dominator_analysis, ptr_inst](ir::Instruction* use) { + [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) { if (use->opcode() == SpvOpLoad || use->opcode() == SpvOpImageTexelPointer) { // TODO: If there are many load in the same BB as |store_inst| the @@ -194,7 +194,7 @@ bool CopyPropagateArrays::HasValidReferencesOnly(ir::Instruction* ptr_inst, std::unique_ptr CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) { - ir::Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result); + Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result); switch (result_inst->opcode()) { case SpvOpLoad: @@ -213,11 +213,11 @@ CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) { } std::unique_ptr -CopyPropagateArrays::BuildMemoryObjectFromLoad(ir::Instruction* load_inst) { +CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) { std::vector components_in_reverse; analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); - ir::Instruction* current_inst = def_use_mgr->GetDef( + Instruction* current_inst = def_use_mgr->GetDef( load_inst->GetSingleWordInOperand(kLoadPointerInOperand)); // Build the access chain for the memory object by collecting the indices used @@ -251,8 +251,7 @@ CopyPropagateArrays::BuildMemoryObjectFromLoad(ir::Instruction* load_inst) { } std::unique_ptr -CopyPropagateArrays::BuildMemoryObjectFromExtract( - ir::Instruction* extract_inst) { +CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) { assert(extract_inst->opcode() == SpvOpCompositeExtract && "Expecting an OpCompositeExtract instruction."); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); @@ -269,7 +268,7 @@ CopyPropagateArrays::BuildMemoryObjectFromExtract( // Convert the indices in the extract instruction to a series of ids that // can be used by the |OpAccessChain| instruction. for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) { - uint32_t index = extract_inst->GetSingleWordInOperand(1); + uint32_t index = extract_inst->GetSingleWordInOperand(i); const analysis::Constant* index_const = const_mgr->GetConstant(uint32_type, {index}); components.push_back( @@ -283,7 +282,7 @@ CopyPropagateArrays::BuildMemoryObjectFromExtract( std::unique_ptr CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( - ir::Instruction* conststruct_inst) { + Instruction* conststruct_inst) { assert(conststruct_inst->opcode() == SpvOpCompositeConstruct && "Expecting an OpCompositeConstruct instruction."); @@ -347,7 +346,7 @@ CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( } std::unique_ptr -CopyPropagateArrays::BuildMemoryObjectFromInsert(ir::Instruction* insert_inst) { +CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) { assert(insert_inst->opcode() == SpvOpCompositeInsert && "Expecting an OpCompositeInsert instruction."); @@ -406,7 +405,7 @@ CopyPropagateArrays::BuildMemoryObjectFromInsert(ir::Instruction* insert_inst) { memory_object->GetParent(); - ir::Instruction* current_insert = + Instruction* current_insert = def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1)); for (uint32_t i = number_of_elements - 1; i > 0; --i) { if (current_insert->opcode() != SpvOpCompositeInsert) { @@ -468,7 +467,7 @@ bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) { return false; } -bool CopyPropagateArrays::CanUpdateUses(ir::Instruction* original_ptr_inst, +bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst, uint32_t type_id) { analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); @@ -487,7 +486,7 @@ bool CopyPropagateArrays::CanUpdateUses(ir::Instruction* original_ptr_inst, return def_use_mgr->WhileEachUse( original_ptr_inst, - [this, type_mgr, const_mgr, type](ir::Instruction* use, uint32_t) { + [this, type_mgr, const_mgr, type](Instruction* use, uint32_t) { switch (use->opcode()) { case SpvOpLoad: { analysis::Pointer* pointer_type = type->AsPointer(); @@ -519,8 +518,8 @@ bool CopyPropagateArrays::CanUpdateUses(ir::Instruction* original_ptr_inst, const analysis::Type* new_pointee_type = type_mgr->GetMemberType(pointee_type, access_chain); - opt::analysis::Pointer pointerTy(new_pointee_type, - pointer_type->storage_class()); + analysis::Pointer pointerTy(new_pointee_type, + pointer_type->storage_class()); uint32_t new_pointer_type_id = context()->get_type_mgr()->GetTypeInstruction(&pointerTy); @@ -560,8 +559,8 @@ bool CopyPropagateArrays::CanUpdateUses(ir::Instruction* original_ptr_inst, } }); } -void CopyPropagateArrays::UpdateUses(ir::Instruction* original_ptr_inst, - ir::Instruction* new_ptr_inst) { +void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, + Instruction* new_ptr_inst) { // TODO (s-perron): Keep the def-use manager up to date. Not done now because // it can cause problems for the |ForEachUse| traversals. Can be use by // keeping a list of instructions that need updating, and then updating them @@ -571,14 +570,14 @@ void CopyPropagateArrays::UpdateUses(ir::Instruction* original_ptr_inst, analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); - std::vector > uses; + std::vector > uses; def_use_mgr->ForEachUse(original_ptr_inst, - [&uses](ir::Instruction* use, uint32_t index) { + [&uses](Instruction* use, uint32_t index) { uses.push_back({use, index}); }); for (auto pair : uses) { - ir::Instruction* use = pair.first; + Instruction* use = pair.first; uint32_t index = pair.second; analysis::Pointer* pointer_type = nullptr; switch (use->opcode()) { @@ -625,8 +624,8 @@ void CopyPropagateArrays::UpdateUses(ir::Instruction* original_ptr_inst, type_mgr->GetMemberType(pointee_type, access_chain); // Now build a pointer to the type of the member. - opt::analysis::Pointer new_pointer_type(new_pointee_type, - pointer_type->storage_class()); + analysis::Pointer new_pointer_type(new_pointee_type, + pointer_type->storage_class()); uint32_t new_pointer_type_id = context()->get_type_mgr()->GetTypeInstruction(&new_pointer_type); @@ -671,7 +670,7 @@ void CopyPropagateArrays::UpdateUses(ir::Instruction* original_ptr_inst, // decomposing the object into the base type, which must be the same, // and then rebuilding them. if (index == 1) { - ir::Instruction* target_pointer = def_use_mgr->GetDef( + Instruction* target_pointer = def_use_mgr->GetDef( use->GetSingleWordInOperand(kStorePointerInOperand)); pointer_type = type_mgr->GetType(target_pointer->type_id())->AsPointer(); @@ -701,9 +700,9 @@ void CopyPropagateArrays::UpdateUses(ir::Instruction* original_ptr_inst, } } -uint32_t CopyPropagateArrays::GenerateCopy( - ir::Instruction* object_inst, uint32_t new_type_id, - ir::Instruction* insertion_position) { +uint32_t CopyPropagateArrays::GenerateCopy(Instruction* object_inst, + uint32_t new_type_id, + Instruction* insertion_position) { analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); @@ -712,10 +711,9 @@ uint32_t CopyPropagateArrays::GenerateCopy( return object_inst->result_id(); } - opt::InstructionBuilder ir_builder( + InstructionBuilder ir_builder( context(), insertion_position, - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisDefUse); + IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); analysis::Type* original_type = type_mgr->GetType(original_type_id); analysis::Type* new_type = type_mgr->GetType(new_type_id); @@ -735,7 +733,7 @@ uint32_t CopyPropagateArrays::GenerateCopy( assert(length_const->AsIntConstant()); uint32_t array_length = length_const->AsIntConstant()->GetU32(); for (uint32_t i = 0; i < array_length; i++) { - ir::Instruction* extract = ir_builder.AddCompositeExtract( + Instruction* extract = ir_builder.AddCompositeExtract( original_element_type_id, object_inst->result_id(), {i}); element_ids.push_back( GenerateCopy(extract, new_element_type_id, insertion_position)); @@ -747,13 +745,13 @@ uint32_t CopyPropagateArrays::GenerateCopy( original_type->AsStruct()) { analysis::Struct* new_struct_type = new_type->AsStruct(); - const std::vector& original_types = + const std::vector& original_types = original_struct_type->element_types(); - const std::vector& new_types = + const std::vector& new_types = new_struct_type->element_types(); std::vector element_ids; for (uint32_t i = 0; i < original_types.size(); i++) { - ir::Instruction* extract = ir_builder.AddCompositeExtract( + Instruction* extract = ir_builder.AddCompositeExtract( type_mgr->GetId(original_types[i]), object_inst->result_id(), {i}); element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]), insertion_position)); @@ -777,7 +775,7 @@ void CopyPropagateArrays::MemoryObject::GetMember( } uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() { - ir::IRContext* context = variable_inst_->context(); + IRContext* context = variable_inst_->context(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id()); @@ -804,7 +802,7 @@ uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() { } template -CopyPropagateArrays::MemoryObject::MemoryObject(ir::Instruction* var_inst, +CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst, iterator begin, iterator end) : variable_inst_(var_inst), access_chain_(begin, end) {} diff --git a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h index db6156a2b..abc07165f 100644 --- a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h +++ b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_COPY_PROP_H_ -#define LIBSPIRV_OPT_COPY_PROP_H_ +#ifndef SOURCE_OPT_COPY_PROP_ARRAYS_H_ +#define SOURCE_OPT_COPY_PROP_ARRAYS_H_ -#include "mem_pass.h" +#include +#include + +#include "source/opt/mem_pass.h" namespace spvtools { namespace opt { @@ -38,15 +41,13 @@ namespace opt { class CopyPropagateArrays : public MemPass { public: const char* name() const override { return "copy-propagate-arrays"; } - Status Process(ir::IRContext*) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | ir::IRContext::kAnalysisCFG | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisLoopAnalysis | - ir::IRContext::kAnalysisDecorations | - ir::IRContext::kAnalysisDominatorAnalysis | - ir::IRContext::kAnalysisNameMap; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisCFG | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisDecorations | + IRContext::kAnalysisDominatorAnalysis | IRContext::kAnalysisNameMap; } private: @@ -62,7 +63,7 @@ class CopyPropagateArrays : public MemPass { // are interpreted the same way they would be in an |OpAccessChain| // instruction. template - MemoryObject(ir::Instruction* var_inst, iterator begin, iterator end); + MemoryObject(Instruction* var_inst, iterator begin, iterator end); // Change |this| to now point to the member identified by |access_chain| // (starting from the current member). The elements in |access_chain| are @@ -87,7 +88,7 @@ class CopyPropagateArrays : public MemPass { uint32_t GetNumberOfMembers(); // Returns the owning variable that the memory object is contained in. - ir::Instruction* GetVariable() const { return variable_inst_; } + Instruction* GetVariable() const { return variable_inst_; } // Returns a vector of integers that can be used to access the specific // member that |this| represents starting from the owning variable. These @@ -127,7 +128,7 @@ class CopyPropagateArrays : public MemPass { private: // The variable that owns this memory object. - ir::Instruction* variable_inst_; + Instruction* variable_inst_; // The access chain to reach the particular member the memory object // represents. It should be interpreted the same way the indices in an @@ -142,18 +143,17 @@ class CopyPropagateArrays : public MemPass { // and only identifies very simple cases. If no such memory object can be // found, the return value is |nullptr|. std::unique_ptr FindSourceObjectIfPossible( - ir::Instruction* var_inst, ir::Instruction* store_inst); + Instruction* var_inst, Instruction* store_inst); // Replaces all loads of |var_inst| with a load from |source| instead. // |insertion_pos| is a position where it is possible to construct the // address of |source| and also dominates all of the loads of |var_inst|. - void PropagateObject(ir::Instruction* var_inst, MemoryObject* source, - ir::Instruction* insertion_pos); + void PropagateObject(Instruction* var_inst, MemoryObject* source, + Instruction* insertion_pos); // Returns true if all of the references to |ptr_inst| can be rewritten and // are dominated by |store_inst|. - bool HasValidReferencesOnly(ir::Instruction* ptr_inst, - ir::Instruction* store_inst); + bool HasValidReferencesOnly(Instruction* ptr_inst, Instruction* store_inst); // Returns a memory object that at one time was equivalent to the value in // |result|. If no such memory object exists, the return value is |nullptr|. @@ -163,21 +163,21 @@ class CopyPropagateArrays : public MemPass { // object cannot be identified, the return value is |nullptr|. The opcode of // |load_inst| must be |OpLoad|. std::unique_ptr BuildMemoryObjectFromLoad( - ir::Instruction* load_inst); + Instruction* load_inst); // Returns the memory object that at some point was equivalent to the result // of |extract_inst|. If a memory object cannot be identified, the return // value is |nullptr|. The opcode of |extract_inst| must be // |OpCompositeExtract|. std::unique_ptr BuildMemoryObjectFromExtract( - ir::Instruction* extract_inst); + Instruction* extract_inst); // Returns the memory object that at some point was equivalent to the result // of |construct_inst|. If a memory object cannot be identified, the return // value is |nullptr|. The opcode of |constuct_inst| must be // |OpCompositeConstruct|. std::unique_ptr BuildMemoryObjectFromCompositeConstruct( - ir::Instruction* conststruct_inst); + Instruction* conststruct_inst); // Returns the memory object that at some point was equivalent to the result // of |insert_inst|. If a memory object cannot be identified, the return @@ -186,46 +186,46 @@ class CopyPropagateArrays : public MemPass { // |OpCompositeInsert| instructions that insert the elements one at a time in // order from beginning to end. std::unique_ptr BuildMemoryObjectFromInsert( - ir::Instruction* insert_inst); + Instruction* insert_inst); // Return true if |type_id| is a pointer type whose pointee type is an array. bool IsPointerToArrayType(uint32_t type_id); // Returns true of there are not stores using |ptr_inst| or something derived // from it. - bool HasNoStores(ir::Instruction* ptr_inst); + bool HasNoStores(Instruction* ptr_inst); // Creates an |OpAccessChain| instruction whose result is a pointer the memory // represented by |source|. The new instruction will be placed before // |insertion_point|. |insertion_point| must be part of a function. Returns // the new instruction. - ir::Instruction* BuildNewAccessChain(ir::Instruction* insertion_point, - MemoryObject* source) const; + Instruction* BuildNewAccessChain(Instruction* insertion_point, + MemoryObject* source) const; // Rewrites all uses of |original_ptr| to use |new_pointer_inst| updating // types of other instructions as needed. This function should not be called // if |CanUpdateUses(original_ptr_inst, new_pointer_inst->type_id())| returns // false. - void UpdateUses(ir::Instruction* original_ptr_inst, - ir::Instruction* new_pointer_inst); + void UpdateUses(Instruction* original_ptr_inst, + Instruction* new_pointer_inst); // Return true if |UpdateUses| is able to change all of the uses of // |original_ptr_inst| to |type_id| and still have valid code. - bool CanUpdateUses(ir::Instruction* original_ptr_inst, uint32_t type_id); + bool CanUpdateUses(Instruction* original_ptr_inst, uint32_t type_id); // Returns the id whose value is the same as |object_to_copy| except its type // is |new_type_id|. Any instructions need to generate this value will be // inserted before |insertion_position|. - uint32_t GenerateCopy(ir::Instruction* object_to_copy, uint32_t new_type_id, - ir::Instruction* insertion_position); + uint32_t GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, + Instruction* insertion_position); // Returns a store to |var_inst| that writes to the entire variable, and is // the only store that does so. Note it does not look through OpAccessChain // instruction, so partial stores are not considered. - ir::Instruction* FindStoreInstruction(const ir::Instruction* var_inst) const; + Instruction* FindStoreInstruction(const Instruction* var_inst) const; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_COPY_PROP_H_ +#endif // SOURCE_OPT_COPY_PROP_ARRAYS_H_ diff --git a/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.cpp index 5dfe42e61..b147ef74c 100644 --- a/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.cpp @@ -15,12 +15,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dead_branch_elim_pass.h" +#include "source/opt/dead_branch_elim_pass.h" -#include "cfa.h" -#include "ir_context.h" -#include "iterator.h" -#include "make_unique.h" +#include +#include +#include + +#include "source/cfa.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { @@ -34,7 +38,7 @@ const uint32_t kBranchCondFalseLabIdInIdx = 2; bool DeadBranchElimPass::GetConstCondition(uint32_t condId, bool* condVal) { bool condIsConst; - ir::Instruction* cInst = get_def_use_mgr()->GetDef(condId); + Instruction* cInst = get_def_use_mgr()->GetDef(condId); switch (cInst->opcode()) { case SpvOpConstantFalse: { *condVal = false; @@ -56,9 +60,9 @@ bool DeadBranchElimPass::GetConstCondition(uint32_t condId, bool* condVal) { } bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) { - ir::Instruction* sInst = get_def_use_mgr()->GetDef(selId); + Instruction* sInst = get_def_use_mgr()->GetDef(selId); uint32_t typeId = sInst->type_id(); - ir::Instruction* typeInst = get_def_use_mgr()->GetDef(typeId); + Instruction* typeInst = get_def_use_mgr()->GetDef(typeId); if (!typeInst || (typeInst->opcode() != SpvOpTypeInt)) return false; // TODO(greg-lunarg): Support non-32 bit ints if (typeInst->GetSingleWordInOperand(0) != 32) return false; @@ -72,27 +76,28 @@ bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) { return false; } -void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { +void DeadBranchElimPass::AddBranch(uint32_t labelId, BasicBlock* bp) { assert(get_def_use_mgr()->GetDef(labelId) != nullptr); - std::unique_ptr newBranch(new ir::Instruction( - context(), SpvOpBranch, 0, 0, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); - get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch); + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); + context()->AnalyzeDefUse(&*newBranch); + context()->set_instr_block(&*newBranch, bp); bp->AddInstruction(std::move(newBranch)); } -ir::BasicBlock* DeadBranchElimPass::GetParentBlock(uint32_t id) { +BasicBlock* DeadBranchElimPass::GetParentBlock(uint32_t id) { return context()->get_instr_block(get_def_use_mgr()->GetDef(id)); } bool DeadBranchElimPass::MarkLiveBlocks( - ir::Function* func, std::unordered_set* live_blocks) { - std::unordered_set continues; - std::vector stack; + Function* func, std::unordered_set* live_blocks) { + std::unordered_set continues; + std::vector stack; stack.push_back(&*func->begin()); bool modified = false; while (!stack.empty()) { - ir::BasicBlock* block = stack.back(); + BasicBlock* block = stack.back(); stack.pop_back(); // Live blocks doubles as visited set. @@ -101,7 +106,7 @@ bool DeadBranchElimPass::MarkLiveBlocks( uint32_t cont_id = block->ContinueBlockIdIfAny(); if (cont_id != 0) continues.insert(GetParentBlock(cont_id)); - ir::Instruction* terminator = block->terminator(); + Instruction* terminator = block->terminator(); uint32_t live_lab_id = 0; // Check if the terminator has a single valid successor. if (terminator->opcode() == SpvOpBranchConditional) { @@ -152,9 +157,18 @@ bool DeadBranchElimPass::MarkLiveBlocks( // Remove the merge instruction if it is a selection merge. AddBranch(live_lab_id, block); context()->KillInst(terminator); - ir::Instruction* mergeInst = block->GetMergeInst(); + Instruction* mergeInst = block->GetMergeInst(); if (mergeInst && mergeInst->opcode() == SpvOpSelectionMerge) { - context()->KillInst(mergeInst); + Instruction* first_break = FindFirstExitFromSelectionMerge( + live_lab_id, mergeInst->GetSingleWordInOperand(0)); + if (first_break == nullptr) { + context()->KillInst(mergeInst); + } else { + mergeInst->RemoveFromList(); + first_break->InsertBefore(std::unique_ptr(mergeInst)); + context()->set_instr_block(mergeInst, + context()->get_instr_block(first_break)); + } } stack.push_back(GetParentBlock(live_lab_id)); } else { @@ -170,18 +184,17 @@ bool DeadBranchElimPass::MarkLiveBlocks( } void DeadBranchElimPass::MarkUnreachableStructuredTargets( - const std::unordered_set& live_blocks, - std::unordered_set* unreachable_merges, - std::unordered_map* - unreachable_continues) { + const std::unordered_set& live_blocks, + std::unordered_set* unreachable_merges, + std::unordered_map* unreachable_continues) { for (auto block : live_blocks) { if (auto merge_id = block->MergeBlockIdIfAny()) { - ir::BasicBlock* merge_block = GetParentBlock(merge_id); + BasicBlock* merge_block = GetParentBlock(merge_id); if (!live_blocks.count(merge_block)) { unreachable_merges->insert(merge_block); } if (auto cont_id = block->ContinueBlockIdIfAny()) { - ir::BasicBlock* cont_block = GetParentBlock(cont_id); + BasicBlock* cont_block = GetParentBlock(cont_id); if (!live_blocks.count(cont_block)) { (*unreachable_continues)[cont_block] = block; } @@ -191,9 +204,8 @@ void DeadBranchElimPass::MarkUnreachableStructuredTargets( } bool DeadBranchElimPass::FixPhiNodesInLiveBlocks( - ir::Function* func, const std::unordered_set& live_blocks, - const std::unordered_map& - unreachable_continues) { + Function* func, const std::unordered_set& live_blocks, + const std::unordered_map& unreachable_continues) { bool modified = false; for (auto& block : *func) { if (live_blocks.count(&block)) { @@ -204,8 +216,8 @@ bool DeadBranchElimPass::FixPhiNodesInLiveBlocks( bool changed = false; bool backedge_added = false; - ir::Instruction* inst = &*iter; - std::vector operands; + Instruction* inst = &*iter; + std::vector operands; // Build a complete set of operands (not just input operands). Start // with type and result id operands. operands.push_back(inst->GetOperand(0u)); @@ -218,7 +230,7 @@ bool DeadBranchElimPass::FixPhiNodesInLiveBlocks( // However, if there is only one other incoming edge, the OpPhi can be // eliminated. for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) { - ir::BasicBlock* inc = GetParentBlock(inst->GetSingleWordInOperand(i)); + BasicBlock* inc = GetParentBlock(inst->GetSingleWordInOperand(i)); auto cont_iter = unreachable_continues.find(inc); if (cont_iter != unreachable_continues.end() && cont_iter->second == &block && inst->NumInOperands() > 4) { @@ -301,10 +313,9 @@ bool DeadBranchElimPass::FixPhiNodesInLiveBlocks( } bool DeadBranchElimPass::EraseDeadBlocks( - ir::Function* func, const std::unordered_set& live_blocks, - const std::unordered_set& unreachable_merges, - const std::unordered_map& - unreachable_continues) { + Function* func, const std::unordered_set& live_blocks, + const std::unordered_set& unreachable_merges, + const std::unordered_map& unreachable_continues) { bool modified = false; for (auto ebi = func->begin(); ebi != func->end();) { if (unreachable_merges.count(&*ebi)) { @@ -314,8 +325,9 @@ bool DeadBranchElimPass::EraseDeadBlocks( KillAllInsts(&*ebi, false); // Add unreachable terminator. ebi->AddInstruction( - MakeUnique(context(), SpvOpUnreachable, 0, 0, - std::initializer_list{})); + MakeUnique(context(), SpvOpUnreachable, 0, 0, + std::initializer_list{})); + context()->set_instr_block(&*ebi->tail(), &*ebi); modified = true; } ++ebi; @@ -328,11 +340,11 @@ bool DeadBranchElimPass::EraseDeadBlocks( KillAllInsts(&*ebi, false); // Add unconditional branch to header. assert(unreachable_continues.count(&*ebi)); - ebi->AddInstruction( - MakeUnique(context(), SpvOpBranch, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {cont_id}}})); + ebi->AddInstruction(MakeUnique( + context(), SpvOpBranch, 0, 0, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {cont_id}}})); get_def_use_mgr()->AnalyzeInstUse(&*ebi->tail()); + context()->set_instr_block(&*ebi->tail(), &*ebi); modified = true; } ++ebi; @@ -349,13 +361,13 @@ bool DeadBranchElimPass::EraseDeadBlocks( return modified; } -bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) { +bool DeadBranchElimPass::EliminateDeadBranches(Function* func) { bool modified = false; - std::unordered_set live_blocks; + std::unordered_set live_blocks; modified |= MarkLiveBlocks(func, &live_blocks); - std::unordered_set unreachable_merges; - std::unordered_map unreachable_continues; + std::unordered_set unreachable_merges; + std::unordered_map unreachable_continues; MarkUnreachableStructuredTargets(live_blocks, &unreachable_merges, &unreachable_continues); modified |= FixPhiNodesInLiveBlocks(func, live_blocks, unreachable_continues); @@ -365,29 +377,90 @@ bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) { return modified; } -void DeadBranchElimPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); +void DeadBranchElimPass::FixBlockOrder() { + context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG | + IRContext::kAnalysisDominatorAnalysis); + // Reorders blocks according to DFS of dominator tree. + ProcessFunction reorder_dominators = [this](Function* function) { + DominatorAnalysis* dominators = context()->GetDominatorAnalysis(function); + std::vector blocks; + for (auto iter = dominators->GetDomTree().begin(); + iter != dominators->GetDomTree().end(); ++iter) { + if (iter->id() != 0) { + blocks.push_back(iter->bb_); + } + } + for (uint32_t i = 1; i < blocks.size(); ++i) { + function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]); + } + return true; + }; + + // Reorders blocks according to structured order. + ProcessFunction reorder_structured = [this](Function* function) { + std::list order; + context()->cfg()->ComputeStructuredOrder(function, &*function->begin(), + &order); + std::vector blocks; + for (auto block : order) { + blocks.push_back(block); + } + for (uint32_t i = 1; i < blocks.size(); ++i) { + function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]); + } + return true; + }; + + // Structured order is more intuitive so use it where possible. + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) { + ProcessReachableCallTree(reorder_structured, context()); + } else { + ProcessReachableCallTree(reorder_dominators, context()); + } } -Pass::Status DeadBranchElimPass::ProcessImpl() { +Pass::Status DeadBranchElimPass::Process() { // Do not process if module contains OpGroupDecorate. Additional // support required in KillNamesAndDecorates(). // TODO(greg-lunarg): Add support for OpGroupDecorate for (auto& ai : get_module()->annotations()) if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; // Process all entry point functions - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return EliminateDeadBranches(fp); }; bool modified = ProcessReachableCallTree(pfn, context()); + if (modified) FixBlockOrder(); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -DeadBranchElimPass::DeadBranchElimPass() {} - -Pass::Status DeadBranchElimPass::Process(ir::IRContext* module) { - Initialize(module); - return ProcessImpl(); +Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge( + uint32_t start_block_id, uint32_t merge_block_id) { + // To find the "first" exit, we follow branches looking for a conditional + // branch that is not in a nested construct and is not the header of a new + // construct. We follow the control flow from |start_block_id| to find the + // first one. + while (start_block_id != merge_block_id) { + BasicBlock* start_block = context()->get_instr_block(start_block_id); + Instruction* branch = start_block->terminator(); + uint32_t next_block_id = 0; + switch (branch->opcode()) { + case SpvOpBranchConditional: + case SpvOpSwitch: + next_block_id = start_block->MergeBlockIdIfAny(); + if (next_block_id == 0) { + return branch; + } + break; + case SpvOpBranch: + next_block_id = branch->GetSingleWordInOperand(0); + break; + default: + return nullptr; + } + start_block_id = next_block_id; + } + return nullptr; } } // namespace opt diff --git a/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.h b/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.h index 62ec582ca..f8b441207 100644 --- a/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/dead_branch_elim_pass.h @@ -14,8 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_DEAD_BRANCH_ELIM_PASS_H_ -#define LIBSPIRV_OPT_DEAD_BRANCH_ELIM_PASS_H_ +#ifndef SOURCE_OPT_DEAD_BRANCH_ELIM_PASS_H_ +#define SOURCE_OPT_DEAD_BRANCH_ELIM_PASS_H_ #include #include @@ -23,26 +23,28 @@ #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { // See optimizer.hpp for documentation. class DeadBranchElimPass : public MemPass { - using cbb_ptr = const ir::BasicBlock*; + using cbb_ptr = const BasicBlock*; public: - DeadBranchElimPass(); - const char* name() const override { return "eliminate-dead-branches"; } - Status Process(ir::IRContext* context) override; + DeadBranchElimPass() = default; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + const char* name() const override { return "eliminate-dead-branches"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; } private: @@ -55,7 +57,7 @@ class DeadBranchElimPass : public MemPass { bool GetConstInteger(uint32_t valId, uint32_t* value); // Add branch to |labelId| to end of block |bp|. - void AddBranch(uint32_t labelId, ir::BasicBlock* bp); + void AddBranch(uint32_t labelId, BasicBlock* bp); // For function |func|, look for BranchConditionals with constant condition // and convert to a Branch to the indicated label. Delete resulting dead @@ -63,21 +65,21 @@ class DeadBranchElimPass : public MemPass { // invalid control flow. // TODO(greg-lunarg): Remove remaining constant conditional branches and dead // blocks. - bool EliminateDeadBranches(ir::Function* func); + bool EliminateDeadBranches(Function* func); // Returns the basic block containing |id|. // Note: this pass only requires correct instruction block mappings for the // input. This pass does not preserve the block mapping, so it is not kept // up-to-date during processing. - ir::BasicBlock* GetParentBlock(uint32_t id); + BasicBlock* GetParentBlock(uint32_t id); // Marks live blocks reachable from the entry of |func|. Simplifies constant // branches and switches as it proceeds, to limit the number of live blocks. // It is careful not to eliminate backedges even if they are dead, but the // header is live. Likewise, unreachable merge blocks named in live merge // instruction must be retained (though they may be clobbered). - bool MarkLiveBlocks(ir::Function* func, - std::unordered_set* live_blocks); + bool MarkLiveBlocks(Function* func, + std::unordered_set* live_blocks); // Checks for unreachable merge and continue blocks with live headers; those // blocks must be retained. Continues are tracked separately so that a live @@ -87,10 +89,9 @@ class DeadBranchElimPass : public MemPass { // |unreachable_continues| maps the id of an unreachable continue target to // the header block that declares it. void MarkUnreachableStructuredTargets( - const std::unordered_set& live_blocks, - std::unordered_set* unreachable_merges, - std::unordered_map* - unreachable_continues); + const std::unordered_set& live_blocks, + std::unordered_set* unreachable_merges, + std::unordered_map* unreachable_continues); // Fix phis in reachable blocks so that only live (or unremovable) incoming // edges are present. If the block now only has a single live incoming edge, @@ -105,9 +106,8 @@ class DeadBranchElimPass : public MemPass { // |unreachable_continues| maps continue targets that cannot be reached to // merge instruction that declares them. bool FixPhiNodesInLiveBlocks( - ir::Function* func, - const std::unordered_set& live_blocks, - const std::unordered_map& + Function* func, const std::unordered_set& live_blocks, + const std::unordered_map& unreachable_continues); // Erases dead blocks. Any block captured in |unreachable_merges| or @@ -122,17 +122,28 @@ class DeadBranchElimPass : public MemPass { // |unreachable_continues| maps continue targets that cannot be reached to // corresponding header block that declares them. bool EraseDeadBlocks( - ir::Function* func, - const std::unordered_set& live_blocks, - const std::unordered_set& unreachable_merges, - const std::unordered_map& + Function* func, const std::unordered_set& live_blocks, + const std::unordered_set& unreachable_merges, + const std::unordered_map& unreachable_continues); - void Initialize(ir::IRContext* c); - Pass::Status ProcessImpl(); + // Reorders blocks in reachable functions so that they satisfy dominator + // block ordering rules. + void FixBlockOrder(); + + // Return the first branch instruction that is a conditional branch to + // |merge_block_id|. Returns |nullptr| if not such branch exists. If there are + // multiple such branches, the first one is the one that would be executed + // first when running the code. That is, the one that dominates all of the + // others. + // + // |start_block_id| must be a block whose innermost containing merge construct + // has |merge_block_id| as the merge block. + Instruction* FindFirstExitFromSelectionMerge(uint32_t start_block_id, + uint32_t merge_block_id); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_DEAD_BRANCH_ELIM_PASS_H_ +#endif // SOURCE_OPT_DEAD_BRANCH_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.cpp index 55f4efe7b..b42588ff7 100644 --- a/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.cpp @@ -14,15 +14,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dead_insert_elim_pass.h" +#include "source/opt/dead_insert_elim_pass.h" -#include "composite.h" -#include "ir_context.h" -#include "iterator.h" +#include "source/opt/composite.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" #include "spirv/1.2/GLSL.std.450.h" -#include - namespace spvtools { namespace opt { @@ -38,7 +36,7 @@ const uint32_t kInsertCompositeIdInIdx = 1; } // anonymous namespace -uint32_t DeadInsertElimPass::NumComponents(ir::Instruction* typeInst) { +uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) { switch (typeInst->opcode()) { case SpvOpTypeVector: { return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx); @@ -49,10 +47,10 @@ uint32_t DeadInsertElimPass::NumComponents(ir::Instruction* typeInst) { case SpvOpTypeArray: { uint32_t lenId = typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx); - ir::Instruction* lenInst = get_def_use_mgr()->GetDef(lenId); + Instruction* lenInst = get_def_use_mgr()->GetDef(lenId); if (lenInst->opcode() != SpvOpConstant) return 0; uint32_t lenTypeId = lenInst->type_id(); - ir::Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId); + Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId); // TODO(greg-lunarg): Support non-32-bit array length if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) return 0; @@ -65,11 +63,11 @@ uint32_t DeadInsertElimPass::NumComponents(ir::Instruction* typeInst) { } } -void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, - std::vector* pExtIndices, - uint32_t extOffset) { +void DeadInsertElimPass::MarkInsertChain( + Instruction* insertChain, std::vector* pExtIndices, + uint32_t extOffset, std::unordered_set* visited_phis) { // Not currently optimizing array inserts. - ir::Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id()); + Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id()); if (typeInst->opcode() == SpvOpTypeArray) return; // Insert chains are only composed of inserts and phis if (insertChain->opcode() != SpvOpCompositeInsert && @@ -84,12 +82,13 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, for (uint32_t i = 0; i < cnum; i++) { extIndices.clear(); extIndices.push_back(i); - MarkInsertChain(insertChain, &extIndices, 0); + std::unordered_set sub_visited_phis; + MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis); } return; } } - ir::Instruction* insInst = insertChain; + Instruction* insInst = insertChain; while (insInst->opcode() == SpvOpCompositeInsert) { // If no extract indices, mark insert and inserted object (which might // also be an insert chain) and continue up the chain though the input @@ -101,33 +100,37 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, if (pExtIndices == nullptr) { liveInserts_.insert(insInst->result_id()); uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); - MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0); - } + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); // If extract indices match insert, we are done. Mark insert and // inserted object. - else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) { + } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) { liveInserts_.insert(insInst->result_id()); uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); - MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); break; - } // If non-matching intersection, mark insert - else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) { + } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) { liveInserts_.insert(insInst->result_id()); // If more extract indices than insert, we are done. Use remaining // extract indices to mark inserted object. uint32_t numInsertIndices = insInst->NumInOperands() - 2; if (pExtIndices->size() - extOffset > numInsertIndices) { uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); + std::unordered_set obj_visited_phis; MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices, - extOffset + numInsertIndices); + extOffset + numInsertIndices, &obj_visited_phis); break; - } // If fewer extract indices than insert, also mark inserted object and // continue up chain. - else { + } else { uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); - MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); } } // Get next insert in chain @@ -139,14 +142,8 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, if (insInst->opcode() != SpvOpPhi) return; // Mark phi visited to prevent potential infinite loop. If phi is already // visited, return to avoid infinite loop. - auto iter = visitedPhis_.find(insInst->result_id()); - if (iter == visitedPhis_.end()) { - iter = visitedPhis_.emplace(insInst->result_id(), true).first; - } else if (iter->second) { - return; - } else { - iter->second = true; - } + if (visited_phis->count(insInst->result_id()) != 0) return; + visited_phis->insert(insInst->result_id()); // Phis may have duplicate inputs values for different edges, prune incoming // ids lists before recursing. @@ -157,15 +154,12 @@ void DeadInsertElimPass::MarkInsertChain(ir::Instruction* insertChain, std::sort(ids.begin(), ids.end()); auto new_end = std::unique(ids.begin(), ids.end()); for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) { - ir::Instruction* pi = get_def_use_mgr()->GetDef(*id_iter); - MarkInsertChain(pi, pExtIndices, extOffset); + Instruction* pi = get_def_use_mgr()->GetDef(*id_iter); + MarkInsertChain(pi, pExtIndices, extOffset, visited_phis); } - - // Unmark phi when done visiting. - iter->second = false; } -bool DeadInsertElimPass::EliminateDeadInserts(ir::Function* func) { +bool DeadInsertElimPass::EliminateDeadInserts(Function* func) { bool modified = false; bool lastmodified = true; // Each pass can delete dead instructions, thus potentially revealing @@ -177,7 +171,7 @@ bool DeadInsertElimPass::EliminateDeadInserts(ir::Function* func) { return modified; } -bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { +bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) { bool modified = false; liveInserts_.clear(); visitedPhis_.clear(); @@ -186,7 +180,7 @@ bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { for (auto ii = bi->begin(); ii != bi->end(); ++ii) { // Only process Inserts and composite Phis SpvOp op = ii->opcode(); - ir::Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id()); + Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id()); if (op != SpvOpCompositeInsert && (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode()))) continue; @@ -201,7 +195,7 @@ bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { } } const uint32_t id = ii->result_id(); - get_def_use_mgr()->ForEachUser(id, [&ii, this](ir::Instruction* user) { + get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) { switch (user->opcode()) { case SpvOpCompositeInsert: case SpvOpPhi: @@ -216,18 +210,19 @@ bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { ++icnt; }); // Mark all inserts in chain that intersect with extract - MarkInsertChain(&*ii, &extIndices, 0); + std::unordered_set visited_phis; + MarkInsertChain(&*ii, &extIndices, 0, &visited_phis); } break; default: { // Mark inserts in chain for all components - MarkInsertChain(&*ii, nullptr, 0); + MarkInsertChain(&*ii, nullptr, 0, nullptr); } break; } }); } } // Find and disconnect dead inserts - std::vector dead_instructions; + std::vector dead_instructions; for (auto bi = func->begin(); bi != func->end(); ++bi) { for (auto ii = bi->begin(); ii != bi->end(); ++ii) { if (ii->opcode() != SpvOpCompositeInsert) continue; @@ -242,9 +237,9 @@ bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { } // DCE dead inserts while (!dead_instructions.empty()) { - ir::Instruction* inst = dead_instructions.back(); + Instruction* inst = dead_instructions.back(); dead_instructions.pop_back(); - DCEInst(inst, [&dead_instructions](ir::Instruction* other_inst) { + DCEInst(inst, [&dead_instructions](Instruction* other_inst) { auto i = std::find(dead_instructions.begin(), dead_instructions.end(), other_inst); if (i != dead_instructions.end()) { @@ -255,25 +250,14 @@ bool DeadInsertElimPass::EliminateDeadInsertsOnePass(ir::Function* func) { return modified; } -void DeadInsertElimPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); -} - -Pass::Status DeadInsertElimPass::ProcessImpl() { +Pass::Status DeadInsertElimPass::Process() { // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return EliminateDeadInserts(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -DeadInsertElimPass::DeadInsertElimPass() {} - -Pass::Status DeadInsertElimPass::Process(ir::IRContext* c) { - Initialize(c); - return ProcessImpl(); -} - } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.h b/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.h index 97a725d9c..0b111d02c 100644 --- a/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/dead_insert_elim_pass.h @@ -14,20 +14,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_DEAD_INSERT_ELIM_PASS_H_ -#define LIBSPIRV_OPT_DEAD_INSERT_ELIM_PASS_H_ +#ifndef SOURCE_OPT_DEAD_INSERT_ELIM_PASS_H_ +#define SOURCE_OPT_DEAD_INSERT_ELIM_PASS_H_ #include #include #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -35,40 +36,46 @@ namespace opt { // See optimizer.hpp for documentation. class DeadInsertElimPass : public MemPass { public: - DeadInsertElimPass(); + DeadInsertElimPass() = default; + const char* name() const override { return "eliminate-dead-inserts"; } - Status Process(ir::IRContext*) override; + Status Process() override; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; + } private: // Return the number of subcomponents in the composite type |typeId|. // Return 0 if not a composite type or number of components is not a // 32-bit constant. - uint32_t NumComponents(ir::Instruction* typeInst); + uint32_t NumComponents(Instruction* typeInst); // Mark all inserts in instruction chain ending at |insertChain| with // indices that intersect with extract indices |extIndices| starting with // index at |extOffset|. Chains are composed solely of Inserts and Phis. // Mark all inserts in chain if |extIndices| is nullptr. - void MarkInsertChain(ir::Instruction* insertChain, - std::vector* extIndices, uint32_t extOffset); + void MarkInsertChain(Instruction* insertChain, + std::vector* extIndices, uint32_t extOffset, + std::unordered_set* visited_phis); // Perform EliminateDeadInsertsOnePass(|func|) until no modification is // made. Return true if modified. - bool EliminateDeadInserts(ir::Function* func); + bool EliminateDeadInserts(Function* func); // DCE all dead struct, matrix and vector inserts in |func|. An insert is // dead if the value it inserts is never used. Replace any reference to the // insert with its original composite. Return true if modified. Dead inserts // in dependence cycles are not currently eliminated. Dead inserts into // arrays are not currently eliminated. - bool EliminateDeadInsertsOnePass(ir::Function* func); + bool EliminateDeadInsertsOnePass(Function* func); // Return true if all extensions in this module are allowed by this pass. bool AllExtensionsSupported() const; - void Initialize(ir::IRContext* c); - Pass::Status ProcessImpl(); - // Live inserts std::unordered_set liveInserts_; @@ -79,4 +86,4 @@ class DeadInsertElimPass : public MemPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_DEAD_INSERT_ELIM_PASS_H_ +#endif // SOURCE_OPT_DEAD_INSERT_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/dead_variable_elimination.cpp b/3rdparty/spirv-tools/source/opt/dead_variable_elimination.cpp index 1fec3a4f8..283710684 100644 --- a/3rdparty/spirv-tools/source/opt/dead_variable_elimination.cpp +++ b/3rdparty/spirv-tools/source/opt/dead_variable_elimination.cpp @@ -12,24 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dead_variable_elimination.h" +#include "source/opt/dead_variable_elimination.h" -#include "ir_context.h" -#include "reflect.h" +#include + +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" namespace spvtools { namespace opt { // This optimization removes global variables that are not needed because they // are definitely not accessed. -Pass::Status DeadVariableElimination::Process(ir::IRContext* c) { +Pass::Status DeadVariableElimination::Process() { // The algorithm will compute the reference count for every global variable. // Anything with a reference count of 0 will then be deleted. For variables // that might have references that are not explicit in this context, we use - // the - // value kMustKeep as the reference count. - InitializeProcessing(c); - + // the value kMustKeep as the reference count. std::vector ids_to_remove; // Get the reference count for all of the global OpVariable instructions. @@ -45,7 +44,7 @@ Pass::Status DeadVariableElimination::Process(ir::IRContext* c) { // else, so we must keep the variable around. get_decoration_mgr()->ForEachDecoration( result_id, SpvDecorationLinkageAttributes, - [&count](const ir::Instruction& linkage_instruction) { + [&count](const Instruction& linkage_instruction) { uint32_t last_operand = linkage_instruction.NumOperands() - 1; if (linkage_instruction.GetSingleWordOperand(last_operand) == SpvLinkageTypeExport) { @@ -57,13 +56,11 @@ Pass::Status DeadVariableElimination::Process(ir::IRContext* c) { // If we don't have to keep the instruction for other reasons, then look // at the uses and count the number of real references. count = 0; - get_def_use_mgr()->ForEachUser( - result_id, [&count](ir::Instruction* user) { - if (!ir::IsAnnotationInst(user->opcode()) && - user->opcode() != SpvOpName) { - ++count; - } - }); + get_def_use_mgr()->ForEachUser(result_id, [&count](Instruction* user) { + if (!IsAnnotationInst(user->opcode()) && user->opcode() != SpvOpName) { + ++count; + } + }); } reference_count_[result_id] = count; if (count == 0) { @@ -83,14 +80,14 @@ Pass::Status DeadVariableElimination::Process(ir::IRContext* c) { } void DeadVariableElimination::DeleteVariable(uint32_t result_id) { - ir::Instruction* inst = get_def_use_mgr()->GetDef(result_id); + Instruction* inst = get_def_use_mgr()->GetDef(result_id); assert(inst->opcode() == SpvOpVariable && "Should not be trying to delete anything other than an OpVariable."); // Look for an initializer that references another variable. We need to know // if that variable can be deleted after the reference is removed. if (inst->NumOperands() == 4) { - ir::Instruction* initializer = + Instruction* initializer = get_def_use_mgr()->GetDef(inst->GetSingleWordOperand(3)); // TODO: Handle OpSpecConstantOP which might be defined in terms of other diff --git a/3rdparty/spirv-tools/source/opt/dead_variable_elimination.h b/3rdparty/spirv-tools/source/opt/dead_variable_elimination.h index f016e78b2..40a7bc025 100644 --- a/3rdparty/spirv-tools/source/opt/dead_variable_elimination.h +++ b/3rdparty/spirv-tools/source/opt/dead_variable_elimination.h @@ -12,25 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_DEAD_VARIABLE_ELIMINATION_H -#define SPIRV_TOOLS_DEAD_VARIABLE_ELIMINATION_H +#ifndef SOURCE_OPT_DEAD_VARIABLE_ELIMINATION_H_ +#define SOURCE_OPT_DEAD_VARIABLE_ELIMINATION_H_ #include #include -#include "decoration_manager.h" -#include "mem_pass.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/mem_pass.h" namespace spvtools { namespace opt { class DeadVariableElimination : public MemPass { public: - const char* name() const override { return "dead-variable-elimination"; } - Status Process(ir::IRContext* c) override; + const char* name() const override { return "eliminate-dead-variables"; } + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse; } private: @@ -52,4 +52,4 @@ class DeadVariableElimination : public MemPass { } // namespace opt } // namespace spvtools -#endif // SPIRV_TOOLS_DEAD_VARIABLE_ELIMINATION_H +#endif // SOURCE_OPT_DEAD_VARIABLE_ELIMINATION_H_ diff --git a/3rdparty/spirv-tools/source/opt/decoration_manager.cpp b/3rdparty/spirv-tools/source/opt/decoration_manager.cpp index f382d7846..82aa495c9 100644 --- a/3rdparty/spirv-tools/source/opt/decoration_manager.cpp +++ b/3rdparty/spirv-tools/source/opt/decoration_manager.cpp @@ -12,48 +12,50 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoration_manager.h" +#include "source/opt/decoration_manager.h" #include +#include #include #include +#include -#include "ir_context.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { namespace analysis { void DecorationManager::RemoveDecorationsFrom( - uint32_t id, std::function pred) { + uint32_t id, std::function pred) { const auto ids_iter = id_to_decoration_insts_.find(id); if (ids_iter == id_to_decoration_insts_.end()) return; TargetData& decorations_info = ids_iter->second; auto context = module_->context(); - std::vector insts_to_kill; + std::vector insts_to_kill; const bool is_group = !decorations_info.decorate_insts.empty(); // Schedule all direct decorations for removal if instructed as such by // |pred|. - for (ir::Instruction* inst : decorations_info.direct_decorations) + for (Instruction* inst : decorations_info.direct_decorations) if (pred(*inst)) insts_to_kill.push_back(inst); // For all groups being directly applied to |id|, remove |id| (and the // literal if |inst| is an OpGroupMemberDecorate) from the instruction // applying the group. - std::unordered_set indirect_decorations_to_remove; - for (ir::Instruction* inst : decorations_info.indirect_decorations) { + std::unordered_set indirect_decorations_to_remove; + for (Instruction* inst : decorations_info.indirect_decorations) { assert(inst->opcode() == SpvOpGroupDecorate || inst->opcode() == SpvOpGroupMemberDecorate); - std::vector group_decorations_to_keep; + std::vector group_decorations_to_keep; const uint32_t group_id = inst->GetSingleWordInOperand(0u); const auto group_iter = id_to_decoration_insts_.find(group_id); assert(group_iter != id_to_decoration_insts_.end() && "Unknown decoration group"); const auto& group_decorations = group_iter->second.direct_decorations; - for (ir::Instruction* decoration : group_decorations) { + for (Instruction* decoration : group_decorations) { if (!pred(*decoration)) group_decorations_to_keep.push_back(decoration); } @@ -97,9 +99,9 @@ void DecorationManager::RemoveDecorationsFrom( // If only some of the decorations should be kept, clone them and apply // them directly to |id|. if (!group_decorations_to_keep.empty()) { - for (ir::Instruction* decoration : group_decorations_to_keep) { + for (Instruction* decoration : group_decorations_to_keep) { // simply clone decoration and change |group_id| to |id| - std::unique_ptr new_inst( + std::unique_ptr new_inst( decoration->Clone(module_->context())); new_inst->SetInOperand(0, {id}); module_->AddAnnotationInst(std::move(new_inst)); @@ -113,22 +115,22 @@ void DecorationManager::RemoveDecorationsFrom( indirect_decorations.erase( std::remove_if( indirect_decorations.begin(), indirect_decorations.end(), - [&indirect_decorations_to_remove](const ir::Instruction* inst) { + [&indirect_decorations_to_remove](const Instruction* inst) { return indirect_decorations_to_remove.count(inst); }), indirect_decorations.end()); - for (ir::Instruction* inst : insts_to_kill) context->KillInst(inst); + for (Instruction* inst : insts_to_kill) context->KillInst(inst); insts_to_kill.clear(); // Schedule all instructions applying the group for removal if this group no // longer applies decorations, either directly or indirectly. if (is_group && decorations_info.direct_decorations.empty() && decorations_info.indirect_decorations.empty()) { - for (ir::Instruction* inst : decorations_info.decorate_insts) + for (Instruction* inst : decorations_info.decorate_insts) insts_to_kill.push_back(inst); } - for (ir::Instruction* inst : insts_to_kill) context->KillInst(inst); + for (Instruction* inst : insts_to_kill) context->KillInst(inst); if (decorations_info.direct_decorations.empty() && decorations_info.indirect_decorations.empty() && @@ -140,20 +142,20 @@ void DecorationManager::RemoveDecorationsFrom( } } -std::vector DecorationManager::GetDecorationsFor( +std::vector DecorationManager::GetDecorationsFor( uint32_t id, bool include_linkage) { - return InternalGetDecorationsFor(id, include_linkage); + return InternalGetDecorationsFor(id, include_linkage); } -std::vector DecorationManager::GetDecorationsFor( +std::vector DecorationManager::GetDecorationsFor( uint32_t id, bool include_linkage) const { return const_cast(this) - ->InternalGetDecorationsFor(id, include_linkage); + ->InternalGetDecorationsFor(id, include_linkage); } bool DecorationManager::HaveTheSameDecorations(uint32_t id1, uint32_t id2) const { - using InstructionList = std::vector; + using InstructionList = std::vector; using DecorationSet = std::set; const InstructionList decorations_for1 = GetDecorationsFor(id1, false); @@ -167,7 +169,7 @@ bool DecorationManager::HaveTheSameDecorations(uint32_t id1, [](const InstructionList& decoration_list, DecorationSet* decorate_set, DecorationSet* decorate_id_set, DecorationSet* decorate_string_set, DecorationSet* member_decorate_set) { - for (const ir::Instruction* inst : decoration_list) { + for (const Instruction* inst : decoration_list) { std::u32string decoration_payload; // Ignore the opcode and the target as we do not want them to be // compared. @@ -223,8 +225,8 @@ bool DecorationManager::HaveTheSameDecorations(uint32_t id1, // TODO(pierremoreau): If OpDecorateId is referencing an OpConstant, one could // check that the constants are the same rather than just // looking at the constant ID. -bool DecorationManager::AreDecorationsTheSame(const ir::Instruction* inst1, - const ir::Instruction* inst2, +bool DecorationManager::AreDecorationsTheSame(const Instruction* inst1, + const Instruction* inst2, bool ignore_target) const { switch (inst1->opcode()) { case SpvOpDecorate: @@ -250,11 +252,11 @@ void DecorationManager::AnalyzeDecorations() { if (!module_) return; // For each group and instruction, collect all their decoration instructions. - for (ir::Instruction& inst : module_->annotations()) { + for (Instruction& inst : module_->annotations()) { AddDecoration(&inst); } } -void DecorationManager::AddDecoration(ir::Instruction* inst) { +void DecorationManager::AddDecoration(Instruction* inst) { switch (inst->opcode()) { case SpvOpDecorate: case SpvOpDecorateId: @@ -295,8 +297,8 @@ std::vector DecorationManager::InternalGetDecorationsFor( const auto process_direct_decorations = [include_linkage, - &decorations](const std::vector& direct_decorations) { - for (ir::Instruction* inst : direct_decorations) { + &decorations](const std::vector& direct_decorations) { + for (Instruction* inst : direct_decorations) { const bool is_linkage = inst->opcode() == SpvOpDecorate && inst->GetSingleWordInOperand(1u) == SpvDecorationLinkageAttributes; @@ -308,7 +310,7 @@ std::vector DecorationManager::InternalGetDecorationsFor( process_direct_decorations(ids_iter->second.direct_decorations); // Process the decorations of all groups applied to |id|. - for (const ir::Instruction* inst : target_data.indirect_decorations) { + for (const Instruction* inst : target_data.indirect_decorations) { const uint32_t group_id = inst->GetSingleWordInOperand(0u); const auto group_iter = id_to_decoration_insts_.find(group_id); assert(group_iter != id_to_decoration_insts_.end() && "Unknown group ID"); @@ -320,8 +322,8 @@ std::vector DecorationManager::InternalGetDecorationsFor( bool DecorationManager::WhileEachDecoration( uint32_t id, uint32_t decoration, - std::function f) { - for (const ir::Instruction* inst : GetDecorationsFor(id, true)) { + std::function f) { + for (const Instruction* inst : GetDecorationsFor(id, true)) { switch (inst->opcode()) { case SpvOpMemberDecorate: if (inst->GetSingleWordInOperand(2) == decoration) { @@ -344,8 +346,8 @@ bool DecorationManager::WhileEachDecoration( void DecorationManager::ForEachDecoration( uint32_t id, uint32_t decoration, - std::function f) { - WhileEachDecoration(id, decoration, [&f](const ir::Instruction& inst) { + std::function f) { + WhileEachDecoration(id, decoration, [&f](const Instruction& inst) { f(inst); return true; }); @@ -355,9 +357,9 @@ void DecorationManager::CloneDecorations(uint32_t from, uint32_t to) { const auto decoration_list = id_to_decoration_insts_.find(from); if (decoration_list == id_to_decoration_insts_.end()) return; auto context = module_->context(); - for (ir::Instruction* inst : decoration_list->second.direct_decorations) { + for (Instruction* inst : decoration_list->second.direct_decorations) { // simply clone decoration and change |target-id| to |to| - std::unique_ptr new_inst(inst->Clone(module_->context())); + std::unique_ptr new_inst(inst->Clone(module_->context())); new_inst->SetInOperand(0, {to}); module_->AddAnnotationInst(std::move(new_inst)); auto decoration_iter = --module_->annotation_end(); @@ -365,15 +367,15 @@ void DecorationManager::CloneDecorations(uint32_t from, uint32_t to) { } // We need to copy the list of instructions as ForgetUses and AnalyzeUses are // going to modify it. - std::vector indirect_decorations = + std::vector indirect_decorations = decoration_list->second.indirect_decorations; - for (ir::Instruction* inst : indirect_decorations) { + for (Instruction* inst : indirect_decorations) { switch (inst->opcode()) { case SpvOpGroupDecorate: context->ForgetUses(inst); // add |to| to list of decorated id's inst->AddOperand( - ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {to})); + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {to})); context->AnalyzeUses(inst); break; case SpvOpGroupMemberDecorate: { @@ -381,10 +383,10 @@ void DecorationManager::CloneDecorations(uint32_t from, uint32_t to) { // for each (id == from), add (to, literal) as operands const uint32_t num_operands = inst->NumOperands(); for (uint32_t i = 1; i < num_operands; i += 2) { - ir::Operand op = inst->GetOperand(i); + Operand op = inst->GetOperand(i); if (op.words[0] == from) { // add new pair of operands: (to, literal) inst->AddOperand( - ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {to})); + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {to})); op = inst->GetOperand(i + 1); inst->AddOperand(std::move(op)); } @@ -398,8 +400,49 @@ void DecorationManager::CloneDecorations(uint32_t from, uint32_t to) { } } -void DecorationManager::RemoveDecoration(ir::Instruction* inst) { - const auto remove_from_container = [inst](std::vector& v) { +void DecorationManager::CloneDecorations( + uint32_t from, uint32_t to, + const std::vector& decorations_to_copy) { + const auto decoration_list = id_to_decoration_insts_.find(from); + if (decoration_list == id_to_decoration_insts_.end()) return; + auto context = module_->context(); + for (Instruction* inst : decoration_list->second.direct_decorations) { + if (std::find(decorations_to_copy.begin(), decorations_to_copy.end(), + inst->GetSingleWordInOperand(1)) == + decorations_to_copy.end()) { + continue; + } + + // Clone decoration and change |target-id| to |to|. + std::unique_ptr new_inst(inst->Clone(module_->context())); + new_inst->SetInOperand(0, {to}); + module_->AddAnnotationInst(std::move(new_inst)); + auto decoration_iter = --module_->annotation_end(); + context->AnalyzeUses(&*decoration_iter); + } + + // We need to copy the list of instructions as ForgetUses and AnalyzeUses are + // going to modify it. + std::vector indirect_decorations = + decoration_list->second.indirect_decorations; + for (Instruction* inst : indirect_decorations) { + switch (inst->opcode()) { + case SpvOpGroupDecorate: + CloneDecorations(inst->GetSingleWordInOperand(0), to, + decorations_to_copy); + break; + case SpvOpGroupMemberDecorate: { + assert(false && "The source id is not suppose to be a type."); + break; + } + default: + assert(false && "Unexpected decoration instruction"); + } + } +} + +void DecorationManager::RemoveDecoration(Instruction* inst) { + const auto remove_from_container = [inst](std::vector& v) { v.erase(std::remove(v.begin(), v.end(), inst), v.end()); }; @@ -431,7 +474,6 @@ void DecorationManager::RemoveDecoration(ir::Instruction* inst) { break; } } - } // namespace analysis } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/decoration_manager.h b/3rdparty/spirv-tools/source/opt/decoration_manager.h index 40ba13eca..a517ba2d8 100644 --- a/3rdparty/spirv-tools/source/opt/decoration_manager.h +++ b/3rdparty/spirv-tools/source/opt/decoration_manager.h @@ -12,26 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_DECORATION_MANAGER_H_ -#define LIBSPIRV_OPT_DECORATION_MANAGER_H_ +#ifndef SOURCE_OPT_DECORATION_MANAGER_H_ +#define SOURCE_OPT_DECORATION_MANAGER_H_ #include #include #include #include -#include "instruction.h" -#include "module.h" +#include "source/opt/instruction.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { namespace analysis { -// A class for analyzing and managing decorations in an ir::Module. +// A class for analyzing and managing decorations in an Module. class DecorationManager { public: // Constructs a decoration manager from the given |module| - explicit DecorationManager(ir::Module* module) : module_(module) { + explicit DecorationManager(Module* module) : module_(module) { AnalyzeDecorations(); } DecorationManager() = delete; @@ -42,23 +42,23 @@ class DecorationManager { // removed if they have no targets left, and OpDecorationGroup will be // removed if the group is not applied to anyone and contains no decorations. void RemoveDecorationsFrom(uint32_t id, - std::function pred = - [](const ir::Instruction&) { return true; }); + std::function pred = + [](const Instruction&) { return true; }); // Removes all decorations from the result id of |inst|. // // NOTE: This is only meant to be called from ir_context, as only metadata // will be removed, and no actual instruction. - void RemoveDecoration(ir::Instruction* inst); + void RemoveDecoration(Instruction* inst); // Returns a vector of all decorations affecting |id|. If a group is applied // to |id|, the decorations of that group are returned rather than the group // decoration instruction. If |include_linkage| is not set, linkage // decorations won't be returned. - std::vector GetDecorationsFor(uint32_t id, - bool include_linkage); - std::vector GetDecorationsFor( - uint32_t id, bool include_linkage) const; + std::vector GetDecorationsFor(uint32_t id, + bool include_linkage); + std::vector GetDecorationsFor(uint32_t id, + bool include_linkage) const; // Returns whether two IDs have the same decorations. Two SpvOpGroupDecorate // instructions that apply the same decorations but to different IDs, still // count as being the same. @@ -69,22 +69,21 @@ class DecorationManager { // // This is only valid for OpDecorate, OpMemberDecorate and OpDecorateId; it // will return false for other opcodes. - bool AreDecorationsTheSame(const ir::Instruction* inst1, - const ir::Instruction* inst2, + bool AreDecorationsTheSame(const Instruction* inst1, const Instruction* inst2, bool ignore_target) const; // |f| is run on each decoration instruction for |id| with decoration // |decoration|. Processed are all decorations which target |id| either // directly or indirectly by Decoration Groups. void ForEachDecoration(uint32_t id, uint32_t decoration, - std::function f); + std::function f); // |f| is run on each decoration instruction for |id| with decoration // |decoration|. Processes all decoration which target |id| either directly or // indirectly through decoration groups. If |f| returns false, iteration is // terminated and this function returns false. bool WhileEachDecoration(uint32_t id, uint32_t decoration, - std::function f); + std::function f); // Clone all decorations from one id |from|. // The cloned decorations are assigned to the given id |to| and are @@ -92,8 +91,14 @@ class DecorationManager { // This function does not check if the id |to| is already decorated. void CloneDecorations(uint32_t from, uint32_t to); + // Same as above, but only clone the decoration if the decoration operand is + // in |decorations_to_copy|. This function has the extra restriction that + // |from| and |to| must not be an object, not a type. + void CloneDecorations(uint32_t from, uint32_t to, + const std::vector& decorations_to_copy); + // Informs the decoration manager of a new decoration that it needs to track. - void AddDecoration(ir::Instruction* inst); + void AddDecoration(Instruction* inst); private: // Analyzes the defs and uses in the given |module| and populates data @@ -105,19 +110,19 @@ class DecorationManager { // Tracks decoration information of an ID. struct TargetData { - std::vector direct_decorations; // All decorate - // instructions applied - // to the tracked ID. - std::vector indirect_decorations; // All instructions - // applying a group to - // the tracked ID. - std::vector decorate_insts; // All decorate instructions - // applying the decorations - // of the tracked ID to - // targets. - // It is empty if the - // tracked ID is not a - // group. + std::vector direct_decorations; // All decorate + // instructions applied + // to the tracked ID. + std::vector indirect_decorations; // All instructions + // applying a group to + // the tracked ID. + std::vector decorate_insts; // All decorate instructions + // applying the decorations + // of the tracked ID to + // targets. + // It is empty if the + // tracked ID is not a + // group. }; // Mapping from ids to the instructions applying a decoration to those ids. @@ -127,11 +132,11 @@ class DecorationManager { // SpvOpMemberGroupDecorate). std::unordered_map id_to_decoration_insts_; // The enclosing module. - ir::Module* module_; + Module* module_; }; } // namespace analysis } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_DECORATION_MANAGER_H_ +#endif // SOURCE_OPT_DECORATION_MANAGER_H_ diff --git a/3rdparty/spirv-tools/source/opt/def_use_manager.cpp b/3rdparty/spirv-tools/source/opt/def_use_manager.cpp index 6e83b4197..4e3649382 100644 --- a/3rdparty/spirv-tools/source/opt/def_use_manager.cpp +++ b/3rdparty/spirv-tools/source/opt/def_use_manager.cpp @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "def_use_manager.h" +#include "source/opt/def_use_manager.h" #include -#include "log.h" -#include "reflect.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" namespace spvtools { namespace opt { namespace analysis { -void DefUseManager::AnalyzeInstDef(ir::Instruction* inst) { +void DefUseManager::AnalyzeInstDef(Instruction* inst) { const uint32_t def_id = inst->result_id(); if (def_id != 0) { auto iter = id_to_def_.find(def_id); @@ -38,7 +38,7 @@ void DefUseManager::AnalyzeInstDef(ir::Instruction* inst) { } } -void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) { +void DefUseManager::AnalyzeInstUse(Instruction* inst) { // Create entry for the given instruction. Note that the instruction may // not have any in-operands. In such cases, we still need a entry for those // instructions so this manager knows it has seen the instruction later. @@ -57,7 +57,7 @@ void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) { case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: case SPV_OPERAND_TYPE_SCOPE_ID: { uint32_t use_id = inst->GetSingleWordOperand(i); - ir::Instruction* def = GetDef(use_id); + Instruction* def = GetDef(use_id); assert(def && "Definition is not registered."); id_to_users_.insert(UserEntry(def, inst)); used_ids->push_back(use_id); @@ -68,12 +68,12 @@ void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) { } } -void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) { +void DefUseManager::AnalyzeInstDefUse(Instruction* inst) { AnalyzeInstDef(inst); AnalyzeInstUse(inst); } -void DefUseManager::UpdateDefUse(ir::Instruction* inst) { +void DefUseManager::UpdateDefUse(Instruction* inst) { const uint32_t def_id = inst->result_id(); if (def_id != 0) { auto iter = id_to_def_.find(def_id); @@ -84,38 +84,37 @@ void DefUseManager::UpdateDefUse(ir::Instruction* inst) { AnalyzeInstUse(inst); } -ir::Instruction* DefUseManager::GetDef(uint32_t id) { +Instruction* DefUseManager::GetDef(uint32_t id) { auto iter = id_to_def_.find(id); if (iter == id_to_def_.end()) return nullptr; return iter->second; } -const ir::Instruction* DefUseManager::GetDef(uint32_t id) const { +const Instruction* DefUseManager::GetDef(uint32_t id) const { const auto iter = id_to_def_.find(id); if (iter == id_to_def_.end()) return nullptr; return iter->second; } DefUseManager::IdToUsersMap::const_iterator DefUseManager::UsersBegin( - const ir::Instruction* def) const { + const Instruction* def) const { return id_to_users_.lower_bound( - UserEntry(const_cast(def), nullptr)); + UserEntry(const_cast(def), nullptr)); } bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter, const IdToUsersMap::const_iterator& cached_end, - const ir::Instruction* inst) const { + const Instruction* inst) const { return (iter != cached_end && iter->first == inst); } bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter, - const ir::Instruction* inst) const { + const Instruction* inst) const { return UsersNotEnd(iter, id_to_users_.end(), inst); } bool DefUseManager::WhileEachUser( - const ir::Instruction* def, - const std::function& f) const { + const Instruction* def, const std::function& f) const { // Ensure that |def| has been registered. assert(def && (!def->HasResultId() || def == GetDef(def->result_id())) && "Definition is not registered."); @@ -129,27 +128,26 @@ bool DefUseManager::WhileEachUser( } bool DefUseManager::WhileEachUser( - uint32_t id, const std::function& f) const { + uint32_t id, const std::function& f) const { return WhileEachUser(GetDef(id), f); } void DefUseManager::ForEachUser( - const ir::Instruction* def, - const std::function& f) const { - WhileEachUser(def, [&f](ir::Instruction* user) { + const Instruction* def, const std::function& f) const { + WhileEachUser(def, [&f](Instruction* user) { f(user); return true; }); } void DefUseManager::ForEachUser( - uint32_t id, const std::function& f) const { + uint32_t id, const std::function& f) const { ForEachUser(GetDef(id), f); } bool DefUseManager::WhileEachUse( - const ir::Instruction* def, - const std::function& f) const { + const Instruction* def, + const std::function& f) const { // Ensure that |def| has been registered. assert(def && (!def->HasResultId() || def == GetDef(def->result_id())) && "Definition is not registered."); @@ -157,9 +155,9 @@ bool DefUseManager::WhileEachUse( auto end = id_to_users_.end(); for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) { - ir::Instruction* user = iter->second; + Instruction* user = iter->second; for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) { - const ir::Operand& op = user->GetOperand(idx); + const Operand& op = user->GetOperand(idx); if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) { if (def->result_id() == op.words[0]) { if (!f(user, idx)) return false; @@ -171,29 +169,27 @@ bool DefUseManager::WhileEachUse( } bool DefUseManager::WhileEachUse( - uint32_t id, - const std::function& f) const { + uint32_t id, const std::function& f) const { return WhileEachUse(GetDef(id), f); } void DefUseManager::ForEachUse( - const ir::Instruction* def, - const std::function& f) const { - WhileEachUse(def, [&f](ir::Instruction* user, uint32_t index) { + const Instruction* def, + const std::function& f) const { + WhileEachUse(def, [&f](Instruction* user, uint32_t index) { f(user, index); return true; }); } void DefUseManager::ForEachUse( - uint32_t id, - const std::function& f) const { + uint32_t id, const std::function& f) const { ForEachUse(GetDef(id), f); } -uint32_t DefUseManager::NumUsers(const ir::Instruction* def) const { +uint32_t DefUseManager::NumUsers(const Instruction* def) const { uint32_t count = 0; - ForEachUser(def, [&count](ir::Instruction*) { ++count; }); + ForEachUser(def, [&count](Instruction*) { ++count; }); return count; } @@ -201,9 +197,9 @@ uint32_t DefUseManager::NumUsers(uint32_t id) const { return NumUsers(GetDef(id)); } -uint32_t DefUseManager::NumUses(const ir::Instruction* def) const { +uint32_t DefUseManager::NumUses(const Instruction* def) const { uint32_t count = 0; - ForEachUse(def, [&count](ir::Instruction*, uint32_t) { ++count; }); + ForEachUse(def, [&count](Instruction*, uint32_t) { ++count; }); return count; } @@ -211,20 +207,20 @@ uint32_t DefUseManager::NumUses(uint32_t id) const { return NumUses(GetDef(id)); } -std::vector DefUseManager::GetAnnotations(uint32_t id) const { - std::vector annos; - const ir::Instruction* def = GetDef(id); +std::vector DefUseManager::GetAnnotations(uint32_t id) const { + std::vector annos; + const Instruction* def = GetDef(id); if (!def) return annos; - ForEachUser(def, [&annos](ir::Instruction* user) { - if (ir::IsAnnotationInst(user->opcode())) { + ForEachUser(def, [&annos](Instruction* user) { + if (IsAnnotationInst(user->opcode())) { annos.push_back(user); } }); return annos; } -void DefUseManager::AnalyzeDefUse(ir::Module* module) { +void DefUseManager::AnalyzeDefUse(Module* module) { if (!module) return; // Analyze all the defs before any uses to catch forward references. module->ForEachInst( @@ -233,7 +229,7 @@ void DefUseManager::AnalyzeDefUse(ir::Module* module) { std::bind(&DefUseManager::AnalyzeInstUse, this, std::placeholders::_1)); } -void DefUseManager::ClearInst(ir::Instruction* inst) { +void DefUseManager::ClearInst(Instruction* inst) { auto iter = inst_to_used_ids_.find(inst); if (iter != inst_to_used_ids_.end()) { EraseUseRecordsOfOperandIds(inst); @@ -250,14 +246,14 @@ void DefUseManager::ClearInst(ir::Instruction* inst) { } } -void DefUseManager::EraseUseRecordsOfOperandIds(const ir::Instruction* inst) { +void DefUseManager::EraseUseRecordsOfOperandIds(const Instruction* inst) { // Go through all ids used by this instruction, remove this instruction's // uses of them. auto iter = inst_to_used_ids_.find(inst); if (iter != inst_to_used_ids_.end()) { for (auto use_id : iter->second) { id_to_users_.erase( - UserEntry(GetDef(use_id), const_cast(inst))); + UserEntry(GetDef(use_id), const_cast(inst))); } inst_to_used_ids_.erase(inst); } diff --git a/3rdparty/spirv-tools/source/opt/def_use_manager.h b/3rdparty/spirv-tools/source/opt/def_use_manager.h index 206170398..0499e82b5 100644 --- a/3rdparty/spirv-tools/source/opt/def_use_manager.h +++ b/3rdparty/spirv-tools/source/opt/def_use_manager.h @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_DEF_USE_MANAGER_H_ -#define LIBSPIRV_OPT_DEF_USE_MANAGER_H_ +#ifndef SOURCE_OPT_DEF_USE_MANAGER_H_ +#define SOURCE_OPT_DEF_USE_MANAGER_H_ #include #include #include +#include #include -#include "instruction.h" -#include "module.h" +#include "source/opt/instruction.h" +#include "source/opt/module.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { @@ -33,7 +34,7 @@ namespace analysis { // * Ids referenced in OpSectionMerge & OpLoopMerge are considered as use. // * Ids referenced in OpPhi's in operands are considered as use. struct Use { - ir::Instruction* inst; // Instruction using the id. + Instruction* inst; // Instruction using the id. uint32_t operand_index; // logical operand index of the id use. This can be // the index of result type id. }; @@ -58,7 +59,7 @@ inline bool operator<(const Use& lhs, const Use& rhs) { // Definition should never be null. User can be null, however, such an entry // should be used only for searching (e.g. all users of a particular definition) // and never stored in a container. -using UserEntry = std::pair; +using UserEntry = std::pair; // Orders UserEntry for use in associative containers (i.e. less than ordering). // @@ -92,17 +93,17 @@ struct UserEntryLess { } }; -// A class for analyzing and managing defs and uses in an ir::Module. +// A class for analyzing and managing defs and uses in an Module. class DefUseManager { public: - using IdToDefMap = std::unordered_map; + using IdToDefMap = std::unordered_map; using IdToUsersMap = std::set; // Constructs a def-use manager from the given |module|. All internal messages // will be communicated to the outside via the given message |consumer|. This // instance only keeps a reference to the |consumer|, so the |consumer| should // outlive this instance. - DefUseManager(ir::Module* module) { AnalyzeDefUse(module); } + DefUseManager(Module* module) { AnalyzeDefUse(module); } DefUseManager(const DefUseManager&) = delete; DefUseManager(DefUseManager&&) = delete; @@ -110,20 +111,20 @@ class DefUseManager { DefUseManager& operator=(DefUseManager&&) = delete; // Analyzes the defs in the given |inst|. - void AnalyzeInstDef(ir::Instruction* inst); + void AnalyzeInstDef(Instruction* inst); // Analyzes the uses in the given |inst|. // // All operands of |inst| must be analyzed as defs. - void AnalyzeInstUse(ir::Instruction* inst); + void AnalyzeInstUse(Instruction* inst); // Analyzes the defs and uses in the given |inst|. - void AnalyzeInstDefUse(ir::Instruction* inst); + void AnalyzeInstDefUse(Instruction* inst); // Returns the def instruction for the given |id|. If there is no instruction // defining |id|, returns nullptr. - ir::Instruction* GetDef(uint32_t id); - const ir::Instruction* GetDef(uint32_t id) const; + Instruction* GetDef(uint32_t id); + const Instruction* GetDef(uint32_t id) const; // Runs the given function |f| on each unique user instruction of |def| (or // |id|). @@ -132,10 +133,10 @@ class DefUseManager { // only be visited once. // // |def| (or |id|) must be registered as a definition. - void ForEachUser(const ir::Instruction* def, - const std::function& f) const; + void ForEachUser(const Instruction* def, + const std::function& f) const; void ForEachUser(uint32_t id, - const std::function& f) const; + const std::function& f) const; // Runs the given function |f| on each unique user instruction of |def| (or // |id|). If |f| returns false, iteration is terminated and this function @@ -145,10 +146,10 @@ class DefUseManager { // be only be visited once. // // |def| (or |id|) must be registered as a definition. - bool WhileEachUser(const ir::Instruction* def, - const std::function& f) const; + bool WhileEachUser(const Instruction* def, + const std::function& f) const; bool WhileEachUser(uint32_t id, - const std::function& f) const; + const std::function& f) const; // Runs the given function |f| on each unique use of |def| (or // |id|). @@ -157,12 +158,12 @@ class DefUseManager { // visited separately. // // |def| (or |id|) must be registered as a definition. - void ForEachUse(const ir::Instruction* def, - const std::function& f) const; - void ForEachUse(uint32_t id, - const std::function& f) const; + void ForEachUse( + const Instruction* def, + const std::function& f) const; + void ForEachUse( + uint32_t id, + const std::function& f) const; // Runs the given function |f| on each unique use of |def| (or // |id|). If |f| returns false, iteration is terminated and this function @@ -172,19 +173,19 @@ class DefUseManager { // visited separately. // // |def| (or |id|) must be registered as a definition. - bool WhileEachUse(const ir::Instruction* def, - const std::function& f) const; - bool WhileEachUse(uint32_t id, - const std::function& f) const; + bool WhileEachUse( + const Instruction* def, + const std::function& f) const; + bool WhileEachUse( + uint32_t id, + const std::function& f) const; // Returns the number of users of |def| (or |id|). - uint32_t NumUsers(const ir::Instruction* def) const; + uint32_t NumUsers(const Instruction* def) const; uint32_t NumUsers(uint32_t id) const; // Returns the number of uses of |def| (or |id|). - uint32_t NumUses(const ir::Instruction* def) const; + uint32_t NumUses(const Instruction* def) const; uint32_t NumUses(uint32_t id) const; // Returns the annotation instrunctions which are a direct use of the given @@ -192,7 +193,7 @@ class DefUseManager { // group(s), this function will just return the OpGroupDecorate // instrcution(s) which refer to the given id as an operand. The OpDecorate // instructions which decorate the decoration group will not be returned. - std::vector GetAnnotations(uint32_t id) const; + std::vector GetAnnotations(uint32_t id) const; // Returns the map from ids to their def instructions. const IdToDefMap& id_to_defs() const { return id_to_def_; } @@ -204,10 +205,10 @@ class DefUseManager { // record: |inst| uses an |id|, will be removed from the use records of |id|. // If |inst| defines an result id, the use record of this result id will also // be removed. Does nothing if |inst| was not analyzed before. - void ClearInst(ir::Instruction* inst); + void ClearInst(Instruction* inst); // Erases the records that a given instruction uses its operand ids. - void EraseUseRecordsOfOperandIds(const ir::Instruction* inst); + void EraseUseRecordsOfOperandIds(const Instruction* inst); friend bool operator==(const DefUseManager&, const DefUseManager&); friend bool operator!=(const DefUseManager& lhs, const DefUseManager& rhs) { @@ -216,15 +217,15 @@ class DefUseManager { // If |inst| has not already been analysed, then analyses its defintion and // uses. - void UpdateDefUse(ir::Instruction* inst); + void UpdateDefUse(Instruction* inst); private: using InstToUsedIdsMap = - std::unordered_map>; + std::unordered_map>; // Returns the first location that {|def|, nullptr} could be inserted into the // users map without violating ordering. - IdToUsersMap::const_iterator UsersBegin(const ir::Instruction* def) const; + IdToUsersMap::const_iterator UsersBegin(const Instruction* def) const; // Returns true if |iter| has not reached the end of |def|'s users. // @@ -233,14 +234,14 @@ class DefUseManager { // against |cached_end| for validity before other checks. This allows caching // the map's end which is a performance improvement on some platforms. bool UsersNotEnd(const IdToUsersMap::const_iterator& iter, - const ir::Instruction* def) const; + const Instruction* def) const; bool UsersNotEnd(const IdToUsersMap::const_iterator& iter, const IdToUsersMap::const_iterator& cached_end, - const ir::Instruction* def) const; + const Instruction* def) const; // Analyzes the defs and uses in the given |module| and populates data // structures in this class. Does nothing if |module| is nullptr. - void AnalyzeDefUse(ir::Module* module); + void AnalyzeDefUse(Module* module); IdToDefMap id_to_def_; // Mapping from ids to their definitions IdToUsersMap id_to_users_; // Mapping from ids to their users @@ -252,4 +253,4 @@ class DefUseManager { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_DEF_USE_MANAGER_H_ +#endif // SOURCE_OPT_DEF_USE_MANAGER_H_ diff --git a/3rdparty/spirv-tools/source/opt/dominator_analysis.cpp b/3rdparty/spirv-tools/source/opt/dominator_analysis.cpp index 95f2aa420..aef43e69f 100644 --- a/3rdparty/spirv-tools/source/opt/dominator_analysis.cpp +++ b/3rdparty/spirv-tools/source/opt/dominator_analysis.cpp @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dominator_analysis.h" +#include "source/opt/dominator_analysis.h" #include -#include "ir_context.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { -ir::BasicBlock* DominatorAnalysisBase::CommonDominator( - ir::BasicBlock* b1, ir::BasicBlock* b2) const { +BasicBlock* DominatorAnalysisBase::CommonDominator(BasicBlock* b1, + BasicBlock* b2) const { if (!b1 || !b2) return nullptr; - std::unordered_set seen; - ir::BasicBlock* block = b1; + std::unordered_set seen; + BasicBlock* block = b1; while (block && seen.insert(block).second) { block = ImmediateDominator(block); } @@ -39,8 +39,7 @@ ir::BasicBlock* DominatorAnalysisBase::CommonDominator( return block; } -bool DominatorAnalysisBase::Dominates(ir::Instruction* a, - ir::Instruction* b) const { +bool DominatorAnalysisBase::Dominates(Instruction* a, Instruction* b) const { if (!a || !b) { return false; } @@ -49,23 +48,19 @@ bool DominatorAnalysisBase::Dominates(ir::Instruction* a, return true; } - ir::BasicBlock* bb_a = a->context()->get_instr_block(a); - ir::BasicBlock* bb_b = b->context()->get_instr_block(b); + BasicBlock* bb_a = a->context()->get_instr_block(a); + BasicBlock* bb_b = b->context()->get_instr_block(b); if (bb_a != bb_b) { return tree_.Dominates(bb_a, bb_b); } - for (ir::Instruction& inst : *bb_a) { - if (&inst == a) { + Instruction* current_inst = a; + while ((current_inst = current_inst->NextNode())) { + if (current_inst == b) { return true; - } else if (&inst == b) { - return false; } } - assert(false && - "We did not find the load or store in the block they are " - "supposed to be in."); return false; } diff --git a/3rdparty/spirv-tools/source/opt/dominator_analysis.h b/3rdparty/spirv-tools/source/opt/dominator_analysis.h index ea10ce385..a94120a55 100644 --- a/3rdparty/spirv-tools/source/opt/dominator_analysis.h +++ b/3rdparty/spirv-tools/source/opt/dominator_analysis.h @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_DOMINATOR_ANALYSIS_PASS_H_ -#define LIBSPIRV_OPT_DOMINATOR_ANALYSIS_PASS_H_ +#ifndef SOURCE_OPT_DOMINATOR_ANALYSIS_H_ +#define SOURCE_OPT_DOMINATOR_ANALYSIS_H_ #include #include -#include "dominator_tree.h" -#include "module.h" +#include "source/opt/dominator_tree.h" namespace spvtools { namespace opt { @@ -30,13 +29,12 @@ class DominatorAnalysisBase { explicit DominatorAnalysisBase(bool is_post_dom) : tree_(is_post_dom) {} // Calculates the dominator (or postdominator) tree for given function |f|. - inline void InitializeTree(const ir::Function* f, const ir::CFG& cfg) { - tree_.InitializeTree(f, cfg); + inline void InitializeTree(const CFG& cfg, const Function* f) { + tree_.InitializeTree(cfg, f); } // Returns true if BasicBlock |a| dominates BasicBlock |b|. - inline bool Dominates(const ir::BasicBlock* a, - const ir::BasicBlock* b) const { + inline bool Dominates(const BasicBlock* a, const BasicBlock* b) const { if (!a || !b) return false; return Dominates(a->id(), b->id()); } @@ -48,11 +46,11 @@ class DominatorAnalysisBase { } // Returns true if instruction |a| dominates instruction |b|. - bool Dominates(ir::Instruction* a, ir::Instruction* b) const; + bool Dominates(Instruction* a, Instruction* b) const; // Returns true if BasicBlock |a| strictly dominates BasicBlock |b|. - inline bool StrictlyDominates(const ir::BasicBlock* a, - const ir::BasicBlock* b) const { + inline bool StrictlyDominates(const BasicBlock* a, + const BasicBlock* b) const { if (!a || !b) return false; return StrictlyDominates(a->id(), b->id()); } @@ -65,19 +63,19 @@ class DominatorAnalysisBase { // Returns the immediate dominator of |node| or returns nullptr if it is has // no dominator. - inline ir::BasicBlock* ImmediateDominator(const ir::BasicBlock* node) const { + inline BasicBlock* ImmediateDominator(const BasicBlock* node) const { if (!node) return nullptr; return tree_.ImmediateDominator(node); } // Returns the immediate dominator of |node_id| or returns nullptr if it is // has no dominator. Same as above but operates on IDs. - inline ir::BasicBlock* ImmediateDominator(uint32_t node_id) const { + inline BasicBlock* ImmediateDominator(uint32_t node_id) const { return tree_.ImmediateDominator(node_id); } // Returns true if |node| is reachable from the entry. - inline bool IsReachable(const ir::BasicBlock* node) const { + inline bool IsReachable(const BasicBlock* node) const { if (!node) return false; return tree_.ReachableFromRoots(node->id()); } @@ -116,7 +114,7 @@ class DominatorAnalysisBase { // Returns the most immediate basic block that dominates both |b1| and |b2|. // If there is no such basic block, nullptr is returned. - ir::BasicBlock* CommonDominator(ir::BasicBlock* b1, ir::BasicBlock* b2) const; + BasicBlock* CommonDominator(BasicBlock* b1, BasicBlock* b2) const; protected: DominatorTree tree_; @@ -137,4 +135,4 @@ class PostDominatorAnalysis : public DominatorAnalysisBase { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_DOMINATOR_ANALYSIS_PASS_H_ +#endif // SOURCE_OPT_DOMINATOR_ANALYSIS_H_ diff --git a/3rdparty/spirv-tools/source/opt/dominator_tree.cpp b/3rdparty/spirv-tools/source/opt/dominator_tree.cpp index 39339d06b..c9346e1c5 100644 --- a/3rdparty/spirv-tools/source/opt/dominator_tree.cpp +++ b/3rdparty/spirv-tools/source/opt/dominator_tree.cpp @@ -16,11 +16,9 @@ #include #include -#include "cfa.h" -#include "dominator_tree.h" - -using namespace spvtools; -using namespace spvtools::opt; +#include "source/cfa.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/ir_context.h" // Calculates the dominator or postdominator tree for a given function. // 1 - Compute the successors and predecessors for each BasicBlock. We add a @@ -39,13 +37,15 @@ using namespace spvtools::opt; // preorder and postorder index of each node. We use these indexes to compare // nodes against each other for domination checks. +namespace spvtools { +namespace opt { namespace { // Wrapper around CFA::DepthFirstTraversal to provide an interface to perform // depth first search on generic BasicBlock types. Will call post and pre order // user defined functions during traversal // -// BBType - BasicBlock type. Will either be ir::BasicBlock or DominatorTreeNode +// BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode // SuccessorLambda - Lamdba matching the signature of 'const // std::vector*(const BBType *A)'. Will return a vector of the nodes // succeding BasicBlock A. @@ -66,7 +66,7 @@ static void DepthFirstSearch(const BBType* bb, SuccessorLambda successors, // depth first search on generic BasicBlock types. This overload is for only // performing user defined post order. // -// BBType - BasicBlock type. Will either be ir::BasicBlock or DominatorTreeNode +// BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode // SuccessorLambda - Lamdba matching the signature of 'const // std::vector*(const BBType *A)'. Will return a vector of the nodes // succeding BasicBlock A. @@ -84,7 +84,7 @@ static void DepthFirstSearchPostOrder(const BBType* bb, // Small type trait to get the function class type. template struct GetFunctionClass { - using FunctionType = ir::Function; + using FunctionType = Function; }; // Helper class to compute predecessors and successors for each Basic Block in a @@ -98,7 +98,7 @@ struct GetFunctionClass { // returned by this class will be predecessors in the original CFG. template class BasicBlockSuccessorHelper { - // This should eventually become const ir::BasicBlock. + // This should eventually become const BasicBlock. using BasicBlock = BBType; using Function = typename GetFunctionClass::FunctionType; @@ -214,16 +214,13 @@ void BasicBlockSuccessorHelper::CreateSuccessorMap( } // namespace -namespace spvtools { -namespace opt { - bool DominatorTree::StrictlyDominates(uint32_t a, uint32_t b) const { if (a == b) return false; return Dominates(a, b); } -bool DominatorTree::StrictlyDominates(const ir::BasicBlock* a, - const ir::BasicBlock* b) const { +bool DominatorTree::StrictlyDominates(const BasicBlock* a, + const BasicBlock* b) const { return DominatorTree::StrictlyDominates(a->id(), b->id()); } @@ -251,17 +248,15 @@ bool DominatorTree::Dominates(const DominatorTreeNode* a, a->dfs_num_post_ > b->dfs_num_post_; } -bool DominatorTree::Dominates(const ir::BasicBlock* A, - const ir::BasicBlock* B) const { +bool DominatorTree::Dominates(const BasicBlock* A, const BasicBlock* B) const { return Dominates(A->id(), B->id()); } -ir::BasicBlock* DominatorTree::ImmediateDominator( - const ir::BasicBlock* A) const { +BasicBlock* DominatorTree::ImmediateDominator(const BasicBlock* A) const { return ImmediateDominator(A->id()); } -ir::BasicBlock* DominatorTree::ImmediateDominator(uint32_t a) const { +BasicBlock* DominatorTree::ImmediateDominator(uint32_t a) const { // Check that A is a valid node in the tree. auto a_itr = nodes_.find(a); if (a_itr == nodes_.end()) return nullptr; @@ -275,7 +270,7 @@ ir::BasicBlock* DominatorTree::ImmediateDominator(uint32_t a) const { return node->parent_->bb_; } -DominatorTreeNode* DominatorTree::GetOrInsertNode(ir::BasicBlock* bb) { +DominatorTreeNode* DominatorTree::GetOrInsertNode(BasicBlock* bb) { DominatorTreeNode* dtn = nullptr; std::map::iterator node_iter = @@ -283,28 +278,29 @@ DominatorTreeNode* DominatorTree::GetOrInsertNode(ir::BasicBlock* bb) { if (node_iter == nodes_.end()) { dtn = &nodes_.emplace(std::make_pair(bb->id(), DominatorTreeNode{bb})) .first->second; - } else + } else { dtn = &node_iter->second; + } return dtn; } void DominatorTree::GetDominatorEdges( - const ir::Function* f, const ir::BasicBlock* dummy_start_node, - std::vector>* edges) { + const Function* f, const BasicBlock* dummy_start_node, + std::vector>* edges) { // Each time the depth first traversal calls the postorder callback // std::function we push that node into the postorder vector to create our // postorder list. - std::vector postorder; - auto postorder_function = [&](const ir::BasicBlock* b) { + std::vector postorder; + auto postorder_function = [&](const BasicBlock* b) { postorder.push_back(b); }; - // CFA::CalculateDominators requires std::vector + // CFA::CalculateDominators requires std::vector // BB are derived from F, so we need to const cast it at some point // no modification is made on F. - BasicBlockSuccessorHelper helper{ - *const_cast(f), dummy_start_node, postdominator_}; + BasicBlockSuccessorHelper helper{ + *const_cast(f), dummy_start_node, postdominator_}; // The successor function tells DepthFirstTraversal how to move to successive // nodes by providing an interface to get a list of successor nodes from any @@ -320,11 +316,10 @@ void DominatorTree::GetDominatorEdges( // versa. DepthFirstSearchPostOrder(dummy_start_node, successor_functor, postorder_function); - *edges = - CFA::CalculateDominators(postorder, predecessor_functor); + *edges = CFA::CalculateDominators(postorder, predecessor_functor); } -void DominatorTree::InitializeTree(const ir::Function* f, const ir::CFG& cfg) { +void DominatorTree::InitializeTree(const CFG& cfg, const Function* f) { ClearTree(); // Skip over empty functions. @@ -332,11 +327,11 @@ void DominatorTree::InitializeTree(const ir::Function* f, const ir::CFG& cfg) { return; } - const ir::BasicBlock* dummy_start_node = + const BasicBlock* dummy_start_node = postdominator_ ? cfg.pseudo_exit_block() : cfg.pseudo_entry_block(); // Get the immediate dominator for each node. - std::vector> edges; + std::vector> edges; GetDominatorEdges(f, dummy_start_node, &edges); // Transform the vector into the tree structure which we can use to diff --git a/3rdparty/spirv-tools/source/opt/dominator_tree.h b/3rdparty/spirv-tools/source/opt/dominator_tree.h index 39d5e0297..0024bc508 100644 --- a/3rdparty/spirv-tools/source/opt/dominator_tree.h +++ b/3rdparty/spirv-tools/source/opt/dominator_tree.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_ -#define LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_ +#ifndef SOURCE_OPT_DOMINATOR_TREE_H_ +#define SOURCE_OPT_DOMINATOR_TREE_H_ #include #include @@ -21,9 +21,8 @@ #include #include -#include "cfg.h" -#include "module.h" -#include "tree_iterator.h" +#include "source/opt/cfg.h" +#include "source/opt/tree_iterator.h" namespace spvtools { namespace opt { @@ -31,7 +30,7 @@ namespace opt { // children. It also contains two values, for the pre and post indexes in the // tree which are used to compare two nodes. struct DominatorTreeNode { - explicit DominatorTreeNode(ir::BasicBlock* bb) + explicit DominatorTreeNode(BasicBlock* bb) : bb_(bb), parent_(nullptr), children_({}), @@ -77,7 +76,7 @@ struct DominatorTreeNode { inline uint32_t id() const { return bb_->id(); } - ir::BasicBlock* bb_; + BasicBlock* bb_; DominatorTreeNode* parent_; std::vector children_; @@ -156,12 +155,13 @@ class DominatorTree { // Dumps the tree in the graphvis dot format into the |out_stream|. void DumpTreeAsDot(std::ostream& out_stream) const; - // Build the (post-)dominator tree for the function |f| - // Any existing data will be overwritten - void InitializeTree(const ir::Function* f, const ir::CFG& cfg); + // Build the (post-)dominator tree for the given control flow graph + // |cfg| and the function |f|. |f| must exist in the |cfg|. Any + // existing data in the dominator tree will be overwritten + void InitializeTree(const CFG& cfg, const Function* f); // Check if the basic block |a| dominates the basic block |b|. - bool Dominates(const ir::BasicBlock* a, const ir::BasicBlock* b) const; + bool Dominates(const BasicBlock* a, const BasicBlock* b) const; // Check if the basic block id |a| dominates the basic block id |b|. bool Dominates(uint32_t a, uint32_t b) const; @@ -170,8 +170,7 @@ class DominatorTree { bool Dominates(const DominatorTreeNode* a, const DominatorTreeNode* b) const; // Check if the basic block |a| strictly dominates the basic block |b|. - bool StrictlyDominates(const ir::BasicBlock* a, - const ir::BasicBlock* b) const; + bool StrictlyDominates(const BasicBlock* a, const BasicBlock* b) const; // Check if the basic block id |a| strictly dominates the basic block id |b|. bool StrictlyDominates(uint32_t a, uint32_t b) const; @@ -182,15 +181,15 @@ class DominatorTree { const DominatorTreeNode* b) const; // Returns the immediate dominator of basic block |a|. - ir::BasicBlock* ImmediateDominator(const ir::BasicBlock* A) const; + BasicBlock* ImmediateDominator(const BasicBlock* A) const; // Returns the immediate dominator of basic block id |a|. - ir::BasicBlock* ImmediateDominator(uint32_t a) const; + BasicBlock* ImmediateDominator(uint32_t a) const; // Returns true if the basic block |a| is reachable by this tree. A node would // be unreachable if it cannot be reached by traversal from the start node or // for a postdominator tree, cannot be reached from the exit nodes. - inline bool ReachableFromRoots(const ir::BasicBlock* a) const { + inline bool ReachableFromRoots(const BasicBlock* a) const { if (!a) return false; return ReachableFromRoots(a->id()); } @@ -242,12 +241,12 @@ class DominatorTree { // Returns the DominatorTreeNode associated with the basic block |bb|. // If the |bb| is unknown to the dominator tree, it returns null. - inline DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) { + inline DominatorTreeNode* GetTreeNode(BasicBlock* bb) { return GetTreeNode(bb->id()); } // Returns the DominatorTreeNode associated with the basic block |bb|. // If the |bb| is unknown to the dominator tree, it returns null. - inline const DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) const { + inline const DominatorTreeNode* GetTreeNode(BasicBlock* bb) const { return GetTreeNode(bb->id()); } @@ -272,7 +271,7 @@ class DominatorTree { // Adds the basic block |bb| to the tree structure if it doesn't already // exist. - DominatorTreeNode* GetOrInsertNode(ir::BasicBlock* bb); + DominatorTreeNode* GetOrInsertNode(BasicBlock* bb); // Recomputes the DF numbering of the tree. void ResetDFNumbering(); @@ -287,8 +286,8 @@ class DominatorTree { // pair is its immediate dominator. // The root of the tree has themself as immediate dominator. void GetDominatorEdges( - const ir::Function* f, const ir::BasicBlock* dummy_start_node, - std::vector>* edges); + const Function* f, const BasicBlock* dummy_start_node, + std::vector>* edges); // The roots of the tree. std::vector roots_; @@ -303,4 +302,4 @@ class DominatorTree { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_DOMINATOR_ANALYSIS_TREE_H_ +#endif // SOURCE_OPT_DOMINATOR_TREE_H_ diff --git a/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.cpp b/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.cpp index 5e299c6d2..d368bd145 100644 --- a/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.cpp @@ -12,36 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "eliminate_dead_constant_pass.h" +#include "source/opt/eliminate_dead_constant_pass.h" #include #include #include +#include -#include "def_use_manager.h" -#include "ir_context.h" -#include "log.h" -#include "reflect.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" namespace spvtools { namespace opt { -Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { - std::unordered_set working_list; +Pass::Status EliminateDeadConstantPass::Process() { + std::unordered_set working_list; // Traverse all the instructions to get the initial set of dead constants as // working list and count number of real uses for constants. Uses in // annotation instructions do not count. - std::unordered_map use_counts; - std::vector constants = irContext->GetConstants(); + std::unordered_map use_counts; + std::vector constants = context()->GetConstants(); for (auto* c : constants) { uint32_t const_id = c->result_id(); size_t count = 0; - irContext->get_def_use_mgr()->ForEachUse( - const_id, [&count](ir::Instruction* user, uint32_t index) { + context()->get_def_use_mgr()->ForEachUse( + const_id, [&count](Instruction* user, uint32_t index) { (void)index; SpvOp op = user->opcode(); - if (!(ir::IsAnnotationInst(op) || ir::IsDebug1Inst(op) || - ir::IsDebug2Inst(op) || ir::IsDebug3Inst(op))) { + if (!(IsAnnotationInst(op) || IsDebug1Inst(op) || IsDebug2Inst(op) || + IsDebug3Inst(op))) { ++count; } }); @@ -53,9 +54,9 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { // Start from the constants with 0 uses, back trace through the def-use chain // to find all dead constants. - std::unordered_set dead_consts; + std::unordered_set dead_consts; while (!working_list.empty()) { - ir::Instruction* inst = *working_list.begin(); + Instruction* inst = *working_list.begin(); // Back propagate if the instruction contains IDs in its operands. switch (inst->opcode()) { case SpvOp::SpvOpConstantComposite: @@ -68,8 +69,8 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { continue; } uint32_t operand_id = inst->GetSingleWordInOperand(i); - ir::Instruction* def_inst = - irContext->get_def_use_mgr()->GetDef(operand_id); + Instruction* def_inst = + context()->get_def_use_mgr()->GetDef(operand_id); // If the use_count does not have any count for the def_inst, // def_inst must not be a constant, and should be ignored here. if (!use_counts.count(def_inst)) { @@ -93,7 +94,7 @@ Pass::Status EliminateDeadConstantPass::Process(ir::IRContext* irContext) { // Turn all dead instructions and uses of them to nop for (auto* dc : dead_consts) { - irContext->KillDef(dc->result_id()); + context()->KillDef(dc->result_id()); } return dead_consts.empty() ? Status::SuccessWithoutChange : Status::SuccessWithChange; diff --git a/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.h b/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.h index 3ff69f5c7..01692dbf4 100644 --- a/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.h +++ b/3rdparty/spirv-tools/source/opt/eliminate_dead_constant_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ -#define LIBSPIRV_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ +#ifndef SOURCE_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ +#define SOURCE_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,10 +26,10 @@ namespace opt { class EliminateDeadConstantPass : public Pass { public: const char* name() const override { return "eliminate-dead-const"; } - Status Process(ir::IRContext*) override; + Status Process() override; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ +#endif // SOURCE_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.cpp b/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.cpp index 8f9748a67..5be983a58 100644 --- a/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.cpp @@ -12,21 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "eliminate_dead_functions_pass.h" -#include "ir_context.h" +#include "source/opt/eliminate_dead_functions_pass.h" #include +#include "source/opt/ir_context.h" + namespace spvtools { namespace opt { -Pass::Status EliminateDeadFunctionsPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status EliminateDeadFunctionsPass::Process() { // Identify live functions first. Those that are not live // are dead. - std::unordered_set live_function_set; - ProcessFunction mark_live = [&live_function_set](ir::Function* fp) { + std::unordered_set live_function_set; + ProcessFunction mark_live = [&live_function_set](Function* fp) { live_function_set.insert(fp); return false; }; @@ -48,10 +47,10 @@ Pass::Status EliminateDeadFunctionsPass::Process(ir::IRContext* c) { : Pass::Status::SuccessWithoutChange; } -void EliminateDeadFunctionsPass::EliminateFunction(ir::Function* func) { +void EliminateDeadFunctionsPass::EliminateFunction(Function* func) { // Remove all of the instruction in the function body - func->ForEachInst( - [this](ir::Instruction* inst) { context()->KillInst(inst); }, true); + func->ForEachInst([this](Instruction* inst) { context()->KillInst(inst); }, + true); } } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.h b/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.h index adb41bb39..165e9a6b5 100644 --- a/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.h +++ b/3rdparty/spirv-tools/source/opt/eliminate_dead_functions_pass.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ -#define LIBSPIRV_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ +#ifndef SOURCE_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ +#define SOURCE_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ -#include "def_use_manager.h" -#include "function.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/function.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -27,17 +27,17 @@ namespace opt { class EliminateDeadFunctionsPass : public MemPass { public: const char* name() const override { return "eliminate-dead-functions"; } - Status Process(ir::IRContext* c) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse; } private: - void EliminateFunction(ir::Function* func); + void EliminateFunction(Function* func); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ +#endif // SOURCE_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/feature_manager.cpp b/3rdparty/spirv-tools/source/opt/feature_manager.cpp index f9b91bdbe..b7fc16a50 100644 --- a/3rdparty/spirv-tools/source/opt/feature_manager.cpp +++ b/3rdparty/spirv-tools/source/opt/feature_manager.cpp @@ -12,27 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "feature_manager.h" +#include "source/opt/feature_manager.h" + #include #include +#include -#include "enum_string_mapping.h" +#include "source/enum_string_mapping.h" namespace spvtools { namespace opt { -void FeatureManager::Analyze(ir::Module* module) { +void FeatureManager::Analyze(Module* module) { AddExtensions(module); AddCapabilities(module); AddExtInstImportIds(module); } -void FeatureManager::AddExtensions(ir::Module* module) { +void FeatureManager::AddExtensions(Module* module) { for (auto ext : module->extensions()) { const std::string name = reinterpret_cast(ext.GetInOperand(0u).words.data()); - libspirv::Extension extension; - if (libspirv::GetExtensionFromString(name.c_str(), &extension)) { + Extension extension; + if (GetExtensionFromString(name.c_str(), &extension)) { extensions_.Add(extension); } } @@ -46,18 +48,18 @@ void FeatureManager::AddCapability(SpvCapability cap) { spv_operand_desc desc = {}; if (SPV_SUCCESS == grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) { - libspirv::CapabilitySet(desc->numCapabilities, desc->capabilities) + CapabilitySet(desc->numCapabilities, desc->capabilities) .ForEach([this](SpvCapability c) { AddCapability(c); }); } } -void FeatureManager::AddCapabilities(ir::Module* module) { - for (ir::Instruction& inst : module->capabilities()) { +void FeatureManager::AddCapabilities(Module* module) { + for (Instruction& inst : module->capabilities()) { AddCapability(static_cast(inst.GetSingleWordInOperand(0))); } } -void FeatureManager::AddExtInstImportIds(ir::Module* module) { +void FeatureManager::AddExtInstImportIds(Module* module) { extinst_importid_GLSLstd450_ = module->GetExtInstImportId("GLSL.std.450"); } diff --git a/3rdparty/spirv-tools/source/opt/feature_manager.h b/3rdparty/spirv-tools/source/opt/feature_manager.h index b99a776fc..80b2cccf6 100644 --- a/3rdparty/spirv-tools/source/opt/feature_manager.h +++ b/3rdparty/spirv-tools/source/opt/feature_manager.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_FEATURE_MANAGER_H_ -#define LIBSPIRV_OPT_FEATURE_MANAGER_H_ +#ifndef SOURCE_OPT_FEATURE_MANAGER_H_ +#define SOURCE_OPT_FEATURE_MANAGER_H_ -#include "assembly_grammar.h" -#include "extensions.h" -#include "module.h" +#include "source/assembly_grammar.h" +#include "source/extensions.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -25,13 +25,10 @@ namespace opt { // Tracks features enabled by a module. The IRContext has a FeatureManager. class FeatureManager { public: - explicit FeatureManager(const libspirv::AssemblyGrammar& grammar) - : grammar_(grammar) {} + explicit FeatureManager(const AssemblyGrammar& grammar) : grammar_(grammar) {} // Returns true if |ext| is an enabled extension in the module. - bool HasExtension(libspirv::Extension ext) const { - return extensions_.Contains(ext); - } + bool HasExtension(Extension ext) const { return extensions_.Contains(ext); } // Returns true if |cap| is an enabled capability in the module. bool HasCapability(SpvCapability cap) const { @@ -39,12 +36,10 @@ class FeatureManager { } // Analyzes |module| and records enabled extensions and capabilities. - void Analyze(ir::Module* module); + void Analyze(Module* module); - libspirv::CapabilitySet* GetCapabilities() { return &capabilities_; } - const libspirv::CapabilitySet* GetCapabilities() const { - return &capabilities_; - } + CapabilitySet* GetCapabilities() { return &capabilities_; } + const CapabilitySet* GetCapabilities() const { return &capabilities_; } uint32_t GetExtInstImportId_GLSLstd450() const { return extinst_importid_GLSLstd450_; @@ -52,26 +47,26 @@ class FeatureManager { private: // Analyzes |module| and records enabled extensions. - void AddExtensions(ir::Module* module); + void AddExtensions(Module* module); // Adds the given |capability| and all implied capabilities into the current // FeatureManager. void AddCapability(SpvCapability capability); // Analyzes |module| and records enabled capabilities. - void AddCapabilities(ir::Module* module); + void AddCapabilities(Module* module); // Analyzes |module| and records imported external instruction sets. - void AddExtInstImportIds(ir::Module* module); + void AddExtInstImportIds(Module* module); // Auxiliary object for querying SPIR-V grammar facts. - const libspirv::AssemblyGrammar& grammar_; + const AssemblyGrammar& grammar_; // The enabled extensions. - libspirv::ExtensionSet extensions_; + ExtensionSet extensions_; // The enabled capabilities. - libspirv::CapabilitySet capabilities_; + CapabilitySet capabilities_; // Common external instruction import ids, cached for performance. uint32_t extinst_importid_GLSLstd450_ = 0; @@ -80,4 +75,4 @@ class FeatureManager { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_FEATURE_MANAGER_H_ +#endif // SOURCE_OPT_FEATURE_MANAGER_H_ diff --git a/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.cpp b/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.cpp index eac829733..f4de9116f 100644 --- a/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.cpp @@ -12,26 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "flatten_decoration_pass.h" -#include "ir_context.h" +#include "source/opt/flatten_decoration_pass.h" #include +#include #include #include +#include #include +#include "source/opt/ir_context.h" + namespace spvtools { namespace opt { -using ir::Instruction; -using ir::Operand; - using Words = std::vector; using OrderedUsesMap = std::unordered_map; -Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status FlattenDecorationPass::Process() { bool modified = false; // The target Id of OpDecorationGroup instructions. diff --git a/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.h b/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.h index 7db6f8640..6a34f5bb2 100644 --- a/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.h +++ b/3rdparty/spirv-tools/source/opt/flatten_decoration_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_FLATTEN_DECORATION_PASS_H_ -#define LIBSPIRV_OPT_FLATTEN_DECORATION_PASS_H_ +#ifndef SOURCE_OPT_FLATTEN_DECORATION_PASS_H_ +#define SOURCE_OPT_FLATTEN_DECORATION_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -25,11 +25,11 @@ namespace opt { // See optimizer.hpp for documentation. class FlattenDecorationPass : public Pass { public: - const char* name() const override { return "flatten-decoration"; } - Status Process(ir::IRContext*) override; + const char* name() const override { return "flatten-decorations"; } + Status Process() override; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_FLATTEN_DECORATION_PASS_H_ +#endif // SOURCE_OPT_FLATTEN_DECORATION_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/fold.cpp b/3rdparty/spirv-tools/source/opt/fold.cpp index 678c4566e..09d7e5122 100644 --- a/3rdparty/spirv-tools/source/opt/fold.cpp +++ b/3rdparty/spirv-tools/source/opt/fold.cpp @@ -12,21 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "fold.h" +#include "source/opt/fold.h" #include #include #include -#include "const_folding_rules.h" -#include "def_use_manager.h" -#include "folding_rules.h" -#include "ir_builder.h" -#include "ir_context.h" +#include "source/opt/const_folding_rules.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/folding_rules.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { - namespace { #ifndef INT32_MIN @@ -41,9 +40,9 @@ namespace { #define UINT32_MAX 0xffffffff /* 4294967295U */ #endif -// Returns the single-word result from performing the given unary operation on -// the operand value which is passed in as a 32-bit word. -uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) { +} // namespace + +uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const { switch (opcode) { // Arthimetics case SpvOp::SpvOpSNegate: @@ -59,9 +58,8 @@ uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) { } } -// Returns the single-word result from performing the given binary operation on -// the operand values which are passed in as two 32-bit word. -uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) { +uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a, + uint32_t b) const { switch (opcode) { // Arthimetics case SpvOp::SpvOpIAdd: @@ -150,9 +148,8 @@ uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) { } } -// Returns the single-word result from performing the given ternary operation -// on the operand values which are passed in as three 32-bit word. -uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) { +uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, + uint32_t c) const { switch (opcode) { case SpvOp::SpvOpSelect: return (static_cast(a)) ? b : c; @@ -163,12 +160,8 @@ uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) { } } -// Returns the single-word result from performing the given operation on the -// operand words. This only works with 32-bit operations and uses boolean -// convention that 0u is false, and anything else is boolean true. -// TODO(qining): Support operands other than 32-bit wide. -uint32_t OperateWords(SpvOp opcode, - const std::vector& operand_words) { +uint32_t InstructionFolder::OperateWords( + SpvOp opcode, const std::vector& operand_words) const { switch (operand_words.size()) { case 1: return UnaryOperate(opcode, operand_words.front()); @@ -183,10 +176,9 @@ uint32_t OperateWords(SpvOp opcode, } } -bool FoldInstructionInternal(ir::Instruction* inst) { - ir::IRContext* context = inst->context(); +bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const { auto identity_map = [](uint32_t id) { return id; }; - ir::Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); + Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); if (folded_inst != nullptr) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}}); @@ -194,33 +186,26 @@ bool FoldInstructionInternal(ir::Instruction* inst) { } SpvOp opcode = inst->opcode(); - analysis::ConstantManager* const_manager = context->get_constant_mgr(); + analysis::ConstantManager* const_manager = context_->get_constant_mgr(); std::vector constants = const_manager->GetOperandConstants(inst); - static FoldingRules* rules = new FoldingRules(); - for (FoldingRule rule : rules->GetRulesForOpcode(opcode)) { - if (rule(inst, constants)) { + for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) { + if (rule(context_, inst, constants)) { return true; } } return false; } -} // namespace - -const ConstantFoldingRules& GetConstantFoldingRules() { - static ConstantFoldingRules* rules = new ConstantFoldingRules(); - return *rules; -} - // Returns the result of performing an operation on scalar constant operands. // This function extracts the operand values as 32 bit words and returns the // result in 32 bit word. Scalar constants with longer than 32-bit width are // not accepted in this function. -uint32_t FoldScalars(SpvOp opcode, - const std::vector& operands) { +uint32_t InstructionFolder::FoldScalars( + SpvOp opcode, + const std::vector& operands) const { assert(IsFoldableOpcode(opcode) && "Unhandled instruction opcode in FoldScalars"); std::vector operand_values_in_raw_words; @@ -242,23 +227,16 @@ uint32_t FoldScalars(SpvOp opcode, return OperateWords(opcode, operand_values_in_raw_words); } -// Returns true if |inst| is a binary operation that takes two integers as -// parameters and folds to a constant that can be represented as an unsigned -// 32-bit value when the ids have been replaced by |id_map|. If |inst| can be -// folded, the resulting value is returned in |*result|. Valid result types for -// the instruction are any integer (signed or unsigned) with 32-bits or less, or -// a boolean value. -bool FoldBinaryIntegerOpToConstant(ir::Instruction* inst, - std::function id_map, - uint32_t* result) { +bool InstructionFolder::FoldBinaryIntegerOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const { SpvOp opcode = inst->opcode(); - ir::IRContext* context = inst->context(); - analysis::ConstantManager* const_manger = context->get_constant_mgr(); + analysis::ConstantManager* const_manger = context_->get_constant_mgr(); uint32_t ids[2]; const analysis::IntConstant* constants[2]; for (uint32_t i = 0; i < 2; i++) { - const ir::Operand* operand = &inst->GetInOperand(i); + const Operand* operand = &inst->GetInOperand(i); if (operand->type != SPV_OPERAND_TYPE_ID) { return false; } @@ -432,20 +410,16 @@ bool FoldBinaryIntegerOpToConstant(ir::Instruction* inst, return false; } -// Returns true if |inst| is a binary operation on two boolean values, and folds -// to a constant boolean value when the ids have been replaced using |id_map|. -// If |inst| can be folded, the result value is returned in |*result|. -bool FoldBinaryBooleanOpToConstant(ir::Instruction* inst, - std::function id_map, - uint32_t* result) { +bool InstructionFolder::FoldBinaryBooleanOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const { SpvOp opcode = inst->opcode(); - ir::IRContext* context = inst->context(); - analysis::ConstantManager* const_manger = context->get_constant_mgr(); + analysis::ConstantManager* const_manger = context_->get_constant_mgr(); uint32_t ids[2]; const analysis::BoolConstant* constants[2]; for (uint32_t i = 0; i < 2; i++) { - const ir::Operand* operand = &inst->GetInOperand(i); + const Operand* operand = &inst->GetInOperand(i); if (operand->type != SPV_OPERAND_TYPE_ID) { return false; } @@ -484,13 +458,9 @@ bool FoldBinaryBooleanOpToConstant(ir::Instruction* inst, return false; } -// Returns true if |inst| can be folded to an constant when the ids have been -// substituted using id_map. If it can, the value is returned in |result|. If -// not, |result| is unchanged. It is assumed that not all operands are -// constant. Those cases are handled by |FoldScalar|. -bool FoldIntegerOpToConstant(ir::Instruction* inst, - std::function id_map, - uint32_t* result) { +bool InstructionFolder::FoldIntegerOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const { assert(IsFoldableOpcode(inst->opcode()) && "Unhandled instruction opcode in FoldScalars"); switch (inst->NumInOperands()) { @@ -502,9 +472,9 @@ bool FoldIntegerOpToConstant(ir::Instruction* inst, } } -std::vector FoldVectors( +std::vector InstructionFolder::FoldVectors( SpvOp opcode, uint32_t num_dims, - const std::vector& operands) { + const std::vector& operands) const { assert(IsFoldableOpcode(opcode) && "Unhandled instruction opcode in FoldVectors"); std::vector result; @@ -547,7 +517,7 @@ std::vector FoldVectors( return result; } -bool IsFoldableOpcode(SpvOp opcode) { +bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const { // NOTE: Extend to more opcodes as new cases are handled in the folder // functions. switch (opcode) { @@ -589,7 +559,8 @@ bool IsFoldableOpcode(SpvOp opcode) { } } -bool IsFoldableConstant(const analysis::Constant* cst) { +bool InstructionFolder::IsFoldableConstant( + const analysis::Constant* cst) const { // Currently supported constants are 32-bit values or null constants. if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant()) return scalar->words().size() == 1; @@ -597,10 +568,9 @@ bool IsFoldableConstant(const analysis::Constant* cst) { return cst->AsNullConstant() != nullptr; } -ir::Instruction* FoldInstructionToConstant( - ir::Instruction* inst, std::function id_map) { - ir::IRContext* context = inst->context(); - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); +Instruction* InstructionFolder::FoldInstructionToConstant( + Instruction* inst, std::function id_map) const { + analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); if (!inst->IsFoldableByFoldScalar() && !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) { @@ -625,12 +595,13 @@ ir::Instruction* FoldInstructionToConstant( const analysis::Constant* folded_const = nullptr; for (auto rule : GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) { - folded_const = rule(inst, constants); + folded_const = rule(context_, inst, constants); if (folded_const != nullptr) { - ir::Instruction* const_inst = - const_mgr->GetDefiningInstruction(folded_const); + Instruction* const_inst = + const_mgr->GetDefiningInstruction(folded_const, inst->type_id()); + assert(const_inst->type_id() == inst->type_id()); // May be a new instruction that needs to be analysed. - context->UpdateDefUse(const_inst); + context_->UpdateDefUse(const_inst); return const_inst; } } @@ -651,12 +622,14 @@ ir::Instruction* FoldInstructionToConstant( if (successful) { const analysis::Constant* result_const = const_mgr->GetConstant(const_mgr->GetType(inst), {result_val}); - return const_mgr->GetDefiningInstruction(result_const); + Instruction* folded_inst = + const_mgr->GetDefiningInstruction(result_const, inst->type_id()); + return folded_inst; } return nullptr; } -bool IsFoldableType(ir::Instruction* type_inst) { +bool InstructionFolder::IsFoldableType(Instruction* type_inst) const { // Support 32-bit integers. if (type_inst->opcode() == SpvOpTypeInt) { return type_inst->GetSingleWordInOperand(0) == 32; @@ -669,9 +642,9 @@ bool IsFoldableType(ir::Instruction* type_inst) { return false; } -bool FoldInstruction(ir::Instruction* inst) { +bool InstructionFolder::FoldInstruction(Instruction* inst) const { bool modified = false; - ir::Instruction* folded_inst(inst); + Instruction* folded_inst(inst); while (folded_inst->opcode() != SpvOpCopyObject && FoldInstructionInternal(&*folded_inst)) { modified = true; diff --git a/3rdparty/spirv-tools/source/opt/fold.h b/3rdparty/spirv-tools/source/opt/fold.h index 9c6028dfb..0dc7c0ebb 100644 --- a/3rdparty/spirv-tools/source/opt/fold.h +++ b/3rdparty/spirv-tools/source/opt/fold.h @@ -12,87 +12,160 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_FOLD_H_ -#define LIBSPIRV_UTIL_FOLD_H_ +#ifndef SOURCE_OPT_FOLD_H_ +#define SOURCE_OPT_FOLD_H_ #include #include -#include "const_folding_rules.h" -#include "constants.h" -#include "def_use_manager.h" +#include "source/opt/const_folding_rules.h" +#include "source/opt/constants.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/folding_rules.h" namespace spvtools { namespace opt { -// Returns a reference to the ConstnatFoldingRules instance. -const ConstantFoldingRules& GetConstantFoldingRules(); +class InstructionFolder { + public: + explicit InstructionFolder(IRContext* context) : context_(context) {} -// Returns the result of folding a scalar instruction with the given |opcode| -// and |operands|. Each entry in |operands| is a pointer to an -// analysis::Constant instance, which should've been created with the constant -// manager (See IRContext::get_constant_mgr). -// -// It is an error to call this function with an opcode that does not pass the -// IsFoldableOpcode test. If any error occurs during folding, the folder will -// faill with a call to assert. -uint32_t FoldScalars(SpvOp opcode, - const std::vector& operands); + // Returns the result of folding a scalar instruction with the given |opcode| + // and |operands|. Each entry in |operands| is a pointer to an + // analysis::Constant instance, which should've been created with the constant + // manager (See IRContext::get_constant_mgr). + // + // It is an error to call this function with an opcode that does not pass the + // IsFoldableOpcode test. If any error occurs during folding, the folder will + // fail with a call to assert. + uint32_t FoldScalars( + SpvOp opcode, + const std::vector& operands) const; -// Returns the result of performing an operation with the given |opcode| over -// constant vectors with |num_dims| dimensions. Each entry in |operands| is a -// pointer to an analysis::Constant instance, which should've been created with -// the constant manager (See IRContext::get_constant_mgr). -// -// This function iterates through the given vector type constant operands and -// calculates the result for each element of the result vector to return. -// Vectors with longer than 32-bit scalar components are not accepted in this -// function. -// -// It is an error to call this function with an opcode that does not pass the -// IsFoldableOpcode test. If any error occurs during folding, the folder will -// faill with a call to assert. -std::vector FoldVectors( - SpvOp opcode, uint32_t num_dims, - const std::vector& operands); + // Returns the result of performing an operation with the given |opcode| over + // constant vectors with |num_dims| dimensions. Each entry in |operands| is a + // pointer to an analysis::Constant instance, which should've been created + // with the constant manager (See IRContext::get_constant_mgr). + // + // This function iterates through the given vector type constant operands and + // calculates the result for each element of the result vector to return. + // Vectors with longer than 32-bit scalar components are not accepted in this + // function. + // + // It is an error to call this function with an opcode that does not pass the + // IsFoldableOpcode test. If any error occurs during folding, the folder will + // fail with a call to assert. + std::vector FoldVectors( + SpvOp opcode, uint32_t num_dims, + const std::vector& operands) const; -// Returns true if |opcode| represents an operation handled by FoldScalars or -// FoldVectors. -bool IsFoldableOpcode(SpvOp opcode); + // Returns true if |opcode| represents an operation handled by FoldScalars or + // FoldVectors. + bool IsFoldableOpcode(SpvOp opcode) const; -// Returns true if |cst| is supported by FoldScalars and FoldVectors. -bool IsFoldableConstant(const analysis::Constant* cst); + // Returns true if |cst| is supported by FoldScalars and FoldVectors. + bool IsFoldableConstant(const analysis::Constant* cst) const; -// Returns true if |FoldInstructionToConstant| could fold an instruction whose -// result type is |type_inst|. -bool IsFoldableType(ir::Instruction* type_inst); + // Returns true if |FoldInstructionToConstant| could fold an instruction whose + // result type is |type_inst|. + bool IsFoldableType(Instruction* type_inst) const; -// Tries to fold |inst| to a single constant, when the input ids to |inst| have -// been substituted using |id_map|. Returns a pointer to the OpConstant* -// instruction if successful. If necessary, a new constant instruction is -// created and placed in the global values section. -// -// |id_map| is a function that takes one result id and returns another. It can -// be used for things like CCP where it is known that some ids contain a -// constant, but the instruction itself has not been updated yet. This can map -// those ids to the appropriate constants. -ir::Instruction* FoldInstructionToConstant( - ir::Instruction* inst, std::function id_map); + // Tries to fold |inst| to a single constant, when the input ids to |inst| + // have been substituted using |id_map|. Returns a pointer to the OpConstant* + // instruction if successful. If necessary, a new constant instruction is + // created and placed in the global values section. + // + // |id_map| is a function that takes one result id and returns another. It + // can be used for things like CCP where it is known that some ids contain a + // constant, but the instruction itself has not been updated yet. This can + // map those ids to the appropriate constants. + Instruction* FoldInstructionToConstant( + Instruction* inst, std::function id_map) const; + // Returns true if |inst| can be folded into a simpler instruction. + // If |inst| can be simplified, |inst| is overwritten with the simplified + // instruction reusing the same result id. + // + // If |inst| is simplified, it is possible that the resulting code in invalid + // because the instruction is in a bad location. Callers of this function + // have to handle the following cases: + // + // 1) An OpPhi becomes and OpCopyObject - If there are OpPhi instruction after + // |inst| in a basic block then this is invalid. The caller must fix this + // up. + bool FoldInstruction(Instruction* inst) const; -// Returns true if |inst| can be folded into a simpler instruction. -// If |inst| can be simplified, |inst| is overwritten with the simplified -// instruction reusing the same result id. -// -// If |inst| is simplified, it is possible that the resulting code in invalid -// because the instruction is in a bad location. Callers of this function have -// to handle the following cases: -// -// 1) An OpPhi becomes and OpCopyObject - If there are OpPhi instruction after -// |inst| in a basic block then this is invalid. The caller must fix this -// up. -bool FoldInstruction(ir::Instruction* inst); + // Return true if this opcode has a const folding rule associtated with it. + bool HasConstFoldingRule(SpvOp opcode) const { + return GetConstantFoldingRules().HasFoldingRule(opcode); + } + + private: + // Returns a reference to the ConstnatFoldingRules instance. + const ConstantFoldingRules& GetConstantFoldingRules() const { + return const_folding_rules; + } + + // Returns a reference to the FoldingRules instance. + const FoldingRules& GetFoldingRules() const { return folding_rules; } + + // Returns the single-word result from performing the given unary operation on + // the operand value which is passed in as a 32-bit word. + uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) const; + + // Returns the single-word result from performing the given binary operation + // on the operand values which are passed in as two 32-bit word. + uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) const; + + // Returns the single-word result from performing the given ternary operation + // on the operand values which are passed in as three 32-bit word. + uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, + uint32_t c) const; + + // Returns the single-word result from performing the given operation on the + // operand words. This only works with 32-bit operations and uses boolean + // convention that 0u is false, and anything else is boolean true. + // TODO(qining): Support operands other than 32-bit wide. + uint32_t OperateWords(SpvOp opcode, + const std::vector& operand_words) const; + + bool FoldInstructionInternal(Instruction* inst) const; + + // Returns true if |inst| is a binary operation that takes two integers as + // parameters and folds to a constant that can be represented as an unsigned + // 32-bit value when the ids have been replaced by |id_map|. If |inst| can be + // folded, the resulting value is returned in |*result|. Valid result types + // for the instruction are any integer (signed or unsigned) with 32-bits or + // less, or a boolean value. + bool FoldBinaryIntegerOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const; + + // Returns true if |inst| is a binary operation on two boolean values, and + // folds + // to a constant boolean value when the ids have been replaced using |id_map|. + // If |inst| can be folded, the result value is returned in |*result|. + bool FoldBinaryBooleanOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const; + + // Returns true if |inst| can be folded to an constant when the ids have been + // substituted using id_map. If it can, the value is returned in |result|. If + // not, |result| is unchanged. It is assumed that not all operands are + // constant. Those cases are handled by |FoldScalar|. + bool FoldIntegerOpToConstant(Instruction* inst, + const std::function& id_map, + uint32_t* result) const; + + IRContext* context_; + + // Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|. + ConstantFoldingRules const_folding_rules; + + // Folding rules used by |FoldInstruction|. + FoldingRules folding_rules; +}; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_UTIL_FOLD_H_ +#endif // SOURCE_OPT_FOLD_H_ diff --git a/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp index 79eaeade2..663d112d4 100644 --- a/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -12,32 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "fold_spec_constant_op_and_composite_pass.h" +#include "source/opt/fold_spec_constant_op_and_composite_pass.h" #include #include #include -#include "constants.h" -#include "fold.h" -#include "ir_context.h" -#include "make_unique.h" +#include "source/opt/constants.h" +#include "source/opt/fold.h" +#include "source/opt/ir_context.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { -Pass::Status FoldSpecConstantOpAndCompositePass::Process( - ir::IRContext* irContext) { - Initialize(irContext); - return ProcessImpl(irContext); -} - -void FoldSpecConstantOpAndCompositePass::Initialize(ir::IRContext* irContext) { - InitializeProcessing(irContext); -} - -Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl( - ir::IRContext* irContext) { +Pass::Status FoldSpecConstantOpAndCompositePass::Process() { bool modified = false; // Traverse through all the constant defining instructions. For Normal // Constants whose values are determined and do not depend on OpUndef @@ -59,13 +48,13 @@ Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl( // the dependee Spec Constants, all its dependent constants must have been // processed and all its dependent Spec Constants should have been folded if // possible. - ir::Module::inst_iterator next_inst = irContext->types_values_begin(); - for (ir::Module::inst_iterator inst_iter = next_inst; + Module::inst_iterator next_inst = context()->types_values_begin(); + for (Module::inst_iterator inst_iter = next_inst; // Need to re-evaluate the end iterator since we may modify the list of // instructions in this section of the module as the process goes. - inst_iter != irContext->types_values_end(); inst_iter = next_inst) { + inst_iter != context()->types_values_end(); inst_iter = next_inst) { ++next_inst; - ir::Instruction* inst = &*inst_iter; + Instruction* inst = &*inst_iter; // Collect constant values of normal constants and process the // OpSpecConstantOp and OpSpecConstantComposite instructions if possible. // The constant values will be stored in analysis::Constant instances. @@ -121,9 +110,9 @@ Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl( } bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( - ir::Module::inst_iterator* pos) { - ir::Instruction* inst = &**pos; - ir::Instruction* folded_inst = nullptr; + Module::inst_iterator* pos) { + Instruction* inst = &**pos; + Instruction* folded_inst = nullptr; assert(inst->GetInOperand(0).type == SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER && "The first in-operand of OpSpecContantOp instruction must be of " @@ -161,16 +150,16 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent( uint32_t typeId, uint32_t element) const { - ir::Instruction* type = context()->get_def_use_mgr()->GetDef(typeId); + Instruction* type = context()->get_def_use_mgr()->GetDef(typeId); uint32_t subtype = type->GetTypeComponent(element); assert(subtype != 0); return subtype; } -ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract( - ir::Module::inst_iterator* pos) { - ir::Instruction* inst = &**pos; +Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract( + Module::inst_iterator* pos) { + Instruction* inst = &**pos; assert(inst->NumInOperands() - 1 >= 2 && "OpSpecConstantOp CompositeExtract requires at least two non-type " "non-opcode operands."); @@ -218,9 +207,9 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract( current_const, pos); } -ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle( - ir::Module::inst_iterator* pos) { - ir::Instruction* inst = &**pos; +Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle( + Module::inst_iterator* pos) { + Instruction* inst = &**pos; analysis::Vector* result_vec_type = context()->get_constant_mgr()->GetType(inst)->AsVector(); assert(inst->NumInOperands() - 1 > 2 && @@ -290,11 +279,10 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle( "Literal index out of bound of the concatenated vector"); selected_components.push_back(concatenated_components[literal]); } - auto new_vec_const = - new analysis::VectorConstant(result_vec_type, selected_components); + auto new_vec_const = MakeUnique( + result_vec_type, selected_components); auto reg_vec_const = - context()->get_constant_mgr()->RegisterConstant(new_vec_const); - if (reg_vec_const != new_vec_const) delete new_vec_const; + context()->get_constant_mgr()->RegisterConstant(std::move(new_vec_const)); return context()->get_constant_mgr()->BuildInstructionAndAddToModule( reg_vec_const, pos); } @@ -313,9 +301,9 @@ bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) { } else if (auto* it = type->AsInteger()) { if (it->width() == 32) return true; } else if (auto* vt = type->AsVector()) { - if (vt->element_type()->AsBool()) + if (vt->element_type()->AsBool()) { return true; - else if (auto* vit = vt->element_type()->AsInteger()) { + } else if (auto* vit = vt->element_type()->AsInteger()) { if (vit->width() == 32) return true; } } @@ -323,9 +311,9 @@ bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) { } } // namespace -ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( - ir::Module::inst_iterator* pos) { - const ir::Instruction* inst = &**pos; +Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( + Module::inst_iterator* pos) { + const Instruction* inst = &**pos; const analysis::Type* result_type = context()->get_constant_mgr()->GetType(inst); SpvOp spec_opcode = static_cast(inst->GetSingleWordInOperand(0)); @@ -333,8 +321,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( std::vector operands; if (!std::all_of( - inst->cbegin(), inst->cend(), - [&operands, this](const ir::Operand& o) { + inst->cbegin(), inst->cend(), [&operands, this](const Operand& o) { // skip the operands that is not an id. if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) return true; uint32_t id = o.words.front(); @@ -351,7 +338,8 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( if (result_type->AsInteger() || result_type->AsBool()) { // Scalar operation - uint32_t result_val = FoldScalars(spec_opcode, operands); + uint32_t result_val = + context()->get_instruction_folder().FoldScalars(spec_opcode, operands); auto result_const = context()->get_constant_mgr()->GetConstant(result_type, {result_val}); return context()->get_constant_mgr()->BuildInstructionAndAddToModule( @@ -362,7 +350,8 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( result_type->AsVector()->element_type(); uint32_t num_dims = result_type->AsVector()->element_count(); std::vector result_vec = - FoldVectors(spec_opcode, num_dims, operands); + context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims, + operands); std::vector result_vector_components; for (uint32_t r : result_vec) { if (auto rc = @@ -378,11 +367,10 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( assert(false && "Failed to create constants with 32-bit word"); } } - auto new_vec_const = new analysis::VectorConstant(result_type->AsVector(), - result_vector_components); - auto reg_vec_const = - context()->get_constant_mgr()->RegisterConstant(new_vec_const); - if (reg_vec_const != new_vec_const) delete new_vec_const; + auto new_vec_const = MakeUnique( + result_type->AsVector(), result_vector_components); + auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant( + std::move(new_vec_const)); return context()->get_constant_mgr()->BuildInstructionAndAddToModule( reg_vec_const, pos); } else { diff --git a/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.h b/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.h index 5f901eeb1..16271251f 100644 --- a/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.h +++ b/3rdparty/spirv-tools/source/opt/fold_spec_constant_op_and_composite_pass.h @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ -#define LIBSPIRV_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ +#ifndef SOURCE_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ +#define SOURCE_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ #include #include #include -#include "constants.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" -#include "pass.h" -#include "type_manager.h" +#include "source/opt/constants.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" +#include "source/opt/type_manager.h" namespace spvtools { namespace opt { @@ -36,20 +36,13 @@ class FoldSpecConstantOpAndCompositePass : public Pass { const char* name() const override { return "fold-spec-const-op-composite"; } - Status Process(ir::IRContext* irContext) override; + // Iterates through the types-constants-globals section of the given module, + // finds the Spec Constants defined with OpSpecConstantOp and + // OpSpecConstantComposite instructions. If the result value of those spec + // constants can be folded, fold them to their corresponding normal constants. + Status Process() override; private: - // Initializes the type manager, def-use manager and get the maximal id used - // in the module. - void Initialize(ir::IRContext* irContext); - - // The real entry of processing. Iterates through the types-constants-globals - // section of the given module, finds the Spec Constants defined with - // OpSpecConstantOp and OpSpecConstantComposite instructions. If the result - // value of those spec constants can be folded, fold them to their - // corresponding normal constants. - Status ProcessImpl(ir::IRContext* irContext); - // Processes the OpSpecConstantOp instruction pointed by the given // instruction iterator, folds it to normal constants if possible. Returns // true if the spec constant is folded to normal constants. New instructions @@ -59,26 +52,25 @@ class FoldSpecConstantOpAndCompositePass : public Pass { // folding is done successfully, the original OpSpecConstantOp instruction // will be changed to Nop and new folded instruction will be inserted before // it. - bool ProcessOpSpecConstantOp(ir::Module::inst_iterator* pos); + bool ProcessOpSpecConstantOp(Module::inst_iterator* pos); // Try to fold the OpSpecConstantOp CompositeExtract instruction pointed by // the given instruction iterator to a normal constant defining instruction. // Returns the pointer to the new constant defining instruction if succeeded. // Otherwise returns nullptr. - ir::Instruction* DoCompositeExtract(ir::Module::inst_iterator* inst_iter_ptr); + Instruction* DoCompositeExtract(Module::inst_iterator* inst_iter_ptr); // Try to fold the OpSpecConstantOp VectorShuffle instruction pointed by the // given instruction iterator to a normal constant defining instruction. // Returns the pointer to the new constant defining instruction if succeeded. // Otherwise return nullptr. - ir::Instruction* DoVectorShuffle(ir::Module::inst_iterator* inst_iter_ptr); + Instruction* DoVectorShuffle(Module::inst_iterator* inst_iter_ptr); // Try to fold the OpSpecConstantOp instruction // pointed by the given instruction iterator to a normal constant defining // instruction. Returns the pointer to the new constant defining instruction // if succeeded, otherwise return nullptr. - ir::Instruction* DoComponentWiseOperation( - ir::Module::inst_iterator* inst_iter_ptr); + Instruction* DoComponentWiseOperation(Module::inst_iterator* inst_iter_ptr); // Returns the |element|'th subtype of |type|. // @@ -89,4 +81,4 @@ class FoldSpecConstantOpAndCompositePass : public Pass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ +#endif // SOURCE_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/folding_rules.cpp b/3rdparty/spirv-tools/source/opt/folding_rules.cpp index 8e6fba115..c64cedfb1 100644 --- a/3rdparty/spirv-tools/source/opt/folding_rules.cpp +++ b/3rdparty/spirv-tools/source/opt/folding_rules.cpp @@ -12,13 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "folding_rules.h" -#include "latest_version_glsl_std_450_header.h" +#include "source/opt/folding_rules.h" + +#include +#include +#include + +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { - namespace { + const uint32_t kExtractCompositeIdInIdx = 0; const uint32_t kInsertObjectIdInIdx = 0; const uint32_t kInsertCompositeIdInIdx = 1; @@ -26,6 +32,8 @@ const uint32_t kExtInstSetIdInIdx = 0; const uint32_t kExtInstInstructionInIdx = 1; const uint32_t kFMixXIdInIdx = 2; const uint32_t kFMixYIdInIdx = 3; +const uint32_t kFMixAIdInIdx = 4; +const uint32_t kStoreObjectInIdx = 1; // Returns the element width of |type|. uint32_t ElementWidth(const analysis::Type* type) { @@ -69,9 +77,8 @@ const analysis::Constant* ConstInput( return constants[0] ? constants[0] : constants[1]; } -ir::Instruction* NonConstInput(ir::IRContext* context, - const analysis::Constant* c, - ir::Instruction* inst) { +Instruction* NonConstInput(IRContext* context, const analysis::Constant* c, + Instruction* inst) { uint32_t in_op = c ? 1u : 0u; return context->get_def_use_mgr()->GetDef( inst->GetSingleWordInOperand(in_op)); @@ -87,10 +94,10 @@ uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr, assert(width == 32 || width == 64); std::vector words; if (width == 64) { - spvutils::FloatProxy result(c->GetDouble() * -1.0); + utils::FloatProxy result(c->GetDouble() * -1.0); words = result.GetWords(); } else { - spvutils::FloatProxy result(c->GetFloat() * -1.0f); + utils::FloatProxy result(c->GetFloat() * -1.0f); words = result.GetWords(); } @@ -177,11 +184,11 @@ uint32_t Reciprocal(analysis::ConstantManager* const_mgr, assert(width == 32 || width == 64); std::vector words; if (width == 64) { - spvutils::FloatProxy result(1.0 / c->GetDouble()); + spvtools::utils::FloatProxy result(1.0 / c->GetDouble()); if (!IsValidResult(result.getAsFloat())) return 0; words = result.GetWords(); } else { - spvutils::FloatProxy result(1.0f / c->GetFloat()); + spvtools::utils::FloatProxy result(1.0f / c->GetFloat()); if (!IsValidResult(result.getAsFloat())) return 0; words = result.GetWords(); } @@ -193,10 +200,9 @@ uint32_t Reciprocal(analysis::ConstantManager* const_mgr, // Replaces fdiv where second operand is constant with fmul. FoldingRule ReciprocalFDiv() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -238,17 +244,16 @@ FoldingRule ReciprocalFDiv() { // Elides consecutive negate instructions. FoldingRule MergeNegateArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; - ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; - ir::Instruction* op_inst = + Instruction* op_inst = context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) return false; @@ -273,18 +278,17 @@ FoldingRule MergeNegateArithmetic() { // -(x / 2) = x / -2 // -(2 / x) = -2 / x FoldingRule MergeNegateMulDivArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; - ir::Instruction* op_inst = + Instruction* op_inst = context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) return false; @@ -332,18 +336,17 @@ FoldingRule MergeNegateMulDivArithmetic() { // -(x - 2) = 2 - x // -(2 - x) = x - 2 FoldingRule MergeNegateAddSubArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; - ir::Instruction* op_inst = + Instruction* op_inst = context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) return false; @@ -415,19 +418,18 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, uint32_t width = type->AsFloat()->width(); assert(width == 32 || width == 64); std::vector words; -#define FOLD_OP(op) \ - if (width == 64) { \ - spvutils::FloatProxy val = \ - input1->GetDouble() op input2->GetDouble(); \ - double dval = val.getAsFloat(); \ - if (!IsValidResult(dval)) return 0; \ - words = val.GetWords(); \ - } else { \ - spvutils::FloatProxy val = \ - input1->GetFloat() op input2->GetFloat(); \ - float fval = val.getAsFloat(); \ - if (!IsValidResult(fval)) return 0; \ - words = val.GetWords(); \ +#define FOLD_OP(op) \ + if (width == 64) { \ + utils::FloatProxy val = \ + input1->GetDouble() op input2->GetDouble(); \ + double dval = val.getAsFloat(); \ + if (!IsValidResult(dval)) return 0; \ + words = val.GetWords(); \ + } else { \ + utils::FloatProxy val = input1->GetFloat() op input2->GetFloat(); \ + float fval = val.getAsFloat(); \ + if (!IsValidResult(fval)) return 0; \ + words = val.GetWords(); \ } switch (opcode) { case SpvOpFMul: @@ -566,10 +568,9 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, // (x * 2) * 2 = x * 4 // (2 * x) * 2 = x * 4 FoldingRule MergeMulMulArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -582,7 +583,7 @@ FoldingRule MergeMulMulArithmetic() { // Determine the constant input and the variable input in |inst|. const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -616,12 +617,15 @@ FoldingRule MergeMulMulArithmetic() { // 2 * (2 / x) = 4 / x // (x / 2) * 2 = x * 1 // (2 / x) * 2 = 4 / x +// (y / x) * x = y +// x * (y / x) = y FoldingRule MergeMulDivArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (!inst->IsFloatingPointFoldingAllowed()) return false; @@ -629,9 +633,23 @@ FoldingRule MergeMulDivArithmetic() { uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; + for (uint32_t i = 0; i < 2; i++) { + uint32_t op_id = inst->GetSingleWordInOperand(i); + Instruction* op_inst = def_use_mgr->GetDef(op_id); + if (op_inst->opcode() == SpvOpFDiv) { + if (op_inst->GetSingleWordInOperand(1) == + inst->GetSingleWordInOperand(1 - i)) { + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}}); + return true; + } + } + } + const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFDiv) { @@ -676,10 +694,9 @@ FoldingRule MergeMulDivArithmetic() { // (-x) * 2 = x * -2 // 2 * (-x) = x * -2 FoldingRule MergeMulNegateArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -691,7 +708,7 @@ FoldingRule MergeMulNegateArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -717,10 +734,9 @@ FoldingRule MergeMulNegateArithmetic() { // (4 / x) / 2 = 2 / x // (x / 2) / 2 = x / 4 FoldingRule MergeDivDivArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -731,7 +747,7 @@ FoldingRule MergeDivDivArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1 || HasZero(const_input1)) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; bool first_is_variable = constants[0] == nullptr; @@ -786,12 +802,15 @@ FoldingRule MergeDivDivArithmetic() { // 4 / (2 * x) = 2 / x // (x * 4) / 2 = x * 2 // (4 * x) / 2 = x * 2 +// (x * y) / x = y +// (y * x) / x = y FoldingRule MergeDivMulArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv); - ir::IRContext* context = inst->context(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (!inst->IsFloatingPointFoldingAllowed()) return false; @@ -799,9 +818,24 @@ FoldingRule MergeDivMulArithmetic() { uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; + uint32_t op_id = inst->GetSingleWordInOperand(0); + Instruction* op_inst = def_use_mgr->GetDef(op_id); + + if (op_inst->opcode() == SpvOpFMul) { + for (uint32_t i = 0; i < 2; i++) { + if (op_inst->GetSingleWordInOperand(i) == + inst->GetSingleWordInOperand(1)) { + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, + {op_inst->GetSingleWordInOperand(1 - i)}}}); + return true; + } + } + } + const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1 || HasZero(const_input1)) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; bool first_is_variable = constants[0] == nullptr; @@ -843,11 +877,10 @@ FoldingRule MergeDivMulArithmetic() { // (-x) / 2 = x / -2 // 2 / (-x) = 2 / -x FoldingRule MergeDivNegateArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv || inst->opcode() == SpvOpUDiv); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -859,7 +892,7 @@ FoldingRule MergeDivNegateArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -889,10 +922,9 @@ FoldingRule MergeDivNegateArithmetic() { // (-x) + 2 = 2 - x // 2 + (-x) = 2 - x FoldingRule MergeAddNegateArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); - ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); @@ -900,7 +932,7 @@ FoldingRule MergeAddNegateArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -923,10 +955,9 @@ FoldingRule MergeAddNegateArithmetic() { // (-x) - 2 = -2 - x // 2 - (-x) = x + 2 FoldingRule MergeSubNegateArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); - ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); @@ -938,7 +969,7 @@ FoldingRule MergeSubNegateArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -972,10 +1003,9 @@ FoldingRule MergeSubNegateArithmetic() { // 2 + (x + 2) = x + 4 // 2 + (2 + x) = x + 4 FoldingRule MergeAddAddArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); - ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -987,7 +1017,7 @@ FoldingRule MergeAddAddArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -998,7 +1028,7 @@ FoldingRule MergeAddAddArithmetic() { const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; - ir::Instruction* non_const_input = + Instruction* non_const_input = NonConstInput(context, other_constants[0], other_inst); uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), const_input1, const_input2); @@ -1020,10 +1050,9 @@ FoldingRule MergeAddAddArithmetic() { // 2 + (x - 2) = x + 0 // 2 + (2 - x) = 4 - x FoldingRule MergeAddSubArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); - ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1035,7 +1064,7 @@ FoldingRule MergeAddSubArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -1080,10 +1109,9 @@ FoldingRule MergeAddSubArithmetic() { // 2 - (x + 2) = 0 - x // 2 - (2 + x) = 0 - x FoldingRule MergeSubAddArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); - ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1095,7 +1123,7 @@ FoldingRule MergeSubAddArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -1106,7 +1134,7 @@ FoldingRule MergeSubAddArithmetic() { const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; - ir::Instruction* non_const_input = + Instruction* non_const_input = NonConstInput(context, other_constants[0], other_inst); // If the first operand of the sub is not a constant, swap the constants @@ -1146,10 +1174,9 @@ FoldingRule MergeSubAddArithmetic() { // 2 - (x - 2) = 4 - x // 2 - (2 - x) = x + 0 FoldingRule MergeSubSubArithmetic() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); - ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); @@ -1161,7 +1188,7 @@ FoldingRule MergeSubSubArithmetic() { const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; - ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); + Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -1172,7 +1199,7 @@ FoldingRule MergeSubSubArithmetic() { const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; - ir::Instruction* non_const_input = + Instruction* non_const_input = NonConstInput(context, other_constants[0], other_inst); // Merge the constants. @@ -1213,7 +1240,7 @@ FoldingRule MergeSubSubArithmetic() { } FoldingRule IntMultipleBy1() { - return [](ir::Instruction* inst, + return [](IRContext*, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul."); for (uint32_t i = 0; i < 2; i++) { @@ -1239,22 +1266,22 @@ FoldingRule IntMultipleBy1() { } FoldingRule CompositeConstructFeedingExtract() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector&) { // If the input to an OpCompositeExtract is an OpCompositeConstruct, // then we can simply use the appropriate element in the construction. assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); - analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); - ir::Instruction* cinst = def_use_mgr->GetDef(cid); + Instruction* cinst = def_use_mgr->GetDef(cid); if (cinst->opcode() != SpvOpCompositeConstruct) { return false; } - std::vector operands; + std::vector operands; analysis::Type* composite_type = type_mgr->GetType(cinst->type_id()); if (composite_type->AsVector() == nullptr) { // Get the element being extracted from the OpCompositeConstruct @@ -1279,7 +1306,7 @@ FoldingRule CompositeConstructFeedingExtract() { for (uint32_t construct_index = 0; construct_index < cinst->NumInOperands(); ++construct_index) { uint32_t element_id = cinst->GetSingleWordInOperand(construct_index); - ir::Instruction* element_def = def_use_mgr->GetDef(element_id); + Instruction* element_def = def_use_mgr->GetDef(element_id); analysis::Vector* element_type = type_mgr->GetType(element_def->type_id())->AsVector(); if (element_type) { @@ -1324,11 +1351,11 @@ FoldingRule CompositeExtractFeedingConstruct() { // // This is a common code pattern because of the way that scalar replacement // works. - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpCompositeConstruct && "Wrong opcode. Should be OpCompositeConstruct."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); uint32_t original_id = 0; // Check each element to make sure they are: @@ -1337,7 +1364,7 @@ FoldingRule CompositeExtractFeedingConstruct() { // - all extract from the same id. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { uint32_t element_id = inst->GetSingleWordInOperand(i); - ir::Instruction* element_inst = def_use_mgr->GetDef(element_id); + Instruction* element_inst = def_use_mgr->GetDef(element_id); if (element_inst->opcode() != SpvOpCompositeExtract) { return false; @@ -1362,7 +1389,7 @@ FoldingRule CompositeExtractFeedingConstruct() { // The last check it to see that the object being extracted from is the // correct type. - ir::Instruction* original_inst = def_use_mgr->GetDef(original_id); + Instruction* original_inst = def_use_mgr->GetDef(original_id); if (original_inst->type_id() != inst->type_id()) { return false; } @@ -1375,13 +1402,13 @@ FoldingRule CompositeExtractFeedingConstruct() { } FoldingRule InsertFeedingExtract() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); - ir::Instruction* cinst = def_use_mgr->GetDef(cid); + Instruction* cinst = def_use_mgr->GetDef(cid); if (cinst->opcode() != SpvOpCompositeInsert) { return false; @@ -1419,7 +1446,7 @@ FoldingRule InsertFeedingExtract() { // Extracting an element of the value that was inserted. Extract from // that value directly. if (i + 1 == cinst->NumInOperands()) { - std::vector operands; + std::vector operands; operands.push_back( {SPV_OPERAND_TYPE_ID, {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}); @@ -1433,7 +1460,7 @@ FoldingRule InsertFeedingExtract() { // Extracting a value that is disjoint from the element being inserted. // Rewrite the extract to use the composite input to the insert. - std::vector operands; + std::vector operands; operands.push_back( {SPV_OPERAND_TYPE_ID, {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}}); @@ -1450,21 +1477,21 @@ FoldingRule InsertFeedingExtract() { // operands of the VectorShuffle. We just need to adjust the index in the // extract instruction. FoldingRule VectorShuffleFeedingExtract() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); - analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); - analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); - ir::Instruction* cinst = def_use_mgr->GetDef(cid); + Instruction* cinst = def_use_mgr->GetDef(cid); if (cinst->opcode() != SpvOpVectorShuffle) { return false; } // Find the size of the first vector operand of the VectorShuffle - ir::Instruction* first_input = + Instruction* first_input = def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0)); analysis::Type* first_input_type = type_mgr->GetType(first_input->type_id()); @@ -1477,6 +1504,14 @@ FoldingRule VectorShuffleFeedingExtract() { uint32_t new_index = cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1)); + // Extracting an undefined value so fold this extract into an undef. + const uint32_t undef_literal_value = 0xffffffff; + if (new_index == undef_literal_value) { + inst->SetOpcode(SpvOpUndef); + inst->SetInOperands({}); + return true; + } + // Get the id of the of the vector the elemtent comes from, and update the // index if needed. uint32_t new_vector = 0; @@ -1494,55 +1529,117 @@ FoldingRule VectorShuffleFeedingExtract() { }; } +// When an FMix with is feeding an Extract that extracts an element whose +// corresponding |a| in the FMix is 0 or 1, we can extract from one of the +// operands of the FMix. +FoldingRule FMixFeedingExtract() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == SpvOpCompositeExtract && + "Wrong opcode. Should be OpCompositeExtract."); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + + uint32_t composite_id = + inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + Instruction* composite_inst = def_use_mgr->GetDef(composite_id); + + if (composite_inst->opcode() != SpvOpExtInst) { + return false; + } + + uint32_t inst_set_id = + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + + if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) != + inst_set_id || + composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) != + GLSLstd450FMix) { + return false; + } + + // Get the |a| for the FMix instruction. + uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx); + std::unique_ptr a(inst->Clone(context)); + a->SetInOperand(kExtractCompositeIdInIdx, {a_id}); + context->get_instruction_folder().FoldInstruction(a.get()); + + if (a->opcode() != SpvOpCopyObject) { + return false; + } + + const analysis::Constant* a_const = + const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0)); + + if (!a_const) { + return false; + } + + bool use_x = false; + + assert(a_const->type()->AsFloat()); + double element_value = a_const->GetValueAsDouble(); + if (element_value == 0.0) { + use_x = true; + } else if (element_value == 1.0) { + use_x = false; + } else { + return false; + } + + // Get the id of the of the vector the element comes from. + uint32_t new_vector = 0; + if (use_x) { + new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx); + } else { + new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx); + } + + // Update the extract instruction. + inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); + return true; + }; +} + FoldingRule RedundantPhi() { // An OpPhi instruction where all values are the same or the result of the phi // itself, can be replaced by the value itself. - return - [](ir::Instruction* inst, const std::vector&) { - assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi."); + return [](IRContext*, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi."); - ir::IRContext* context = inst->context(); - analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + uint32_t incoming_value = 0; - uint32_t incoming_value = 0; + for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) { + uint32_t op_id = inst->GetSingleWordInOperand(i); + if (op_id == inst->result_id()) { + continue; + } - for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) { - uint32_t op_id = inst->GetSingleWordInOperand(i); - if (op_id == inst->result_id()) { - continue; - } + if (incoming_value == 0) { + incoming_value = op_id; + } else if (op_id != incoming_value) { + // Found two possible value. Can't simplify. + return false; + } + } - ir::Instruction* op_inst = def_use_mgr->GetDef(op_id); - if (op_inst->opcode() == SpvOpUndef) { - // TODO: We should be able to still use op_id if we know that - // the definition of op_id dominates |inst|. - return false; - } + if (incoming_value == 0) { + // Code looks invalid. Don't do anything. + return false; + } - if (incoming_value == 0) { - incoming_value = op_id; - } else if (op_id != incoming_value) { - // Found two possible value. Can't simplify. - return false; - } - } - - if (incoming_value == 0) { - // Code looks invalid. Don't do anything. - return false; - } - - // We have a single incoming value. Simplify using that value. - inst->SetOpcode(SpvOpCopyObject); - inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}}); - return true; - }; + // We have a single incoming value. Simplify using that value. + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}}); + return true; + }; } FoldingRule RedundantSelect() { // An OpSelect instruction where both values are the same or the condition is // constant can be replaced by one of the values - return [](ir::Instruction* inst, + return [](IRContext*, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpSelect && "Wrong opcode. Should be OpSelect."); @@ -1578,7 +1675,7 @@ FoldingRule RedundantSelect() { return true; } else { // Convert to a vector shuffle. - std::vector ops; + std::vector ops; ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}}); ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}}); const analysis::VectorConstant* vector_const = @@ -1658,7 +1755,7 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { } FoldingRule RedundantFAdd() { - return [](ir::Instruction* inst, + return [](IRContext*, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd."); assert(constants.size() == 2); @@ -1683,7 +1780,7 @@ FoldingRule RedundantFAdd() { } FoldingRule RedundantFSub() { - return [](ir::Instruction* inst, + return [](IRContext*, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub."); assert(constants.size() == 2); @@ -1714,7 +1811,7 @@ FoldingRule RedundantFSub() { } FoldingRule RedundantFMul() { - return [](ir::Instruction* inst, + return [](IRContext*, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul."); assert(constants.size() == 2); @@ -1747,7 +1844,7 @@ FoldingRule RedundantFMul() { } FoldingRule RedundantFDiv() { - return [](ir::Instruction* inst, + return [](IRContext*, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv."); assert(constants.size() == 2); @@ -1778,7 +1875,7 @@ FoldingRule RedundantFDiv() { } FoldingRule RedundantFMix() { - return [](ir::Instruction* inst, + return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpExtInst && "Wrong opcode. Should be OpExtInst."); @@ -1788,7 +1885,7 @@ FoldingRule RedundantFMix() { } uint32_t instSetId = - inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId && inst->GetSingleWordInOperand(kExtInstInstructionInIdx) == @@ -1812,19 +1909,279 @@ FoldingRule RedundantFMix() { }; } +// This rule handles addition of zero for integers. +FoldingRule RedundantIAdd() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd."); + + uint32_t operand = std::numeric_limits::max(); + const analysis::Type* operand_type = nullptr; + if (constants[0] && constants[0]->IsZero()) { + operand = inst->GetSingleWordInOperand(1); + operand_type = constants[0]->type(); + } else if (constants[1] && constants[1]->IsZero()) { + operand = inst->GetSingleWordInOperand(0); + operand_type = constants[1]->type(); + } + + if (operand != std::numeric_limits::max()) { + const analysis::Type* inst_type = + context->get_type_mgr()->GetType(inst->type_id()); + if (inst_type->IsSame(operand_type)) { + inst->SetOpcode(SpvOpCopyObject); + } else { + inst->SetOpcode(SpvOpBitcast); + } + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}}); + return true; + } + return false; + }; +} + +// This rule look for a dot with a constant vector containing a single 1 and +// the rest 0s. This is the same as doing an extract. +FoldingRule DotProductDoingExtract() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot."); + + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return false; + } + + for (int i = 0; i < 2; ++i) { + if (!constants[i]) { + continue; + } + + const analysis::Vector* vector_type = constants[i]->type()->AsVector(); + assert(vector_type && "Inputs to OpDot must be vectors."); + const analysis::Float* element_type = + vector_type->element_type()->AsFloat(); + assert(element_type && "Inputs to OpDot must be vectors of floats."); + uint32_t element_width = element_type->width(); + if (element_width != 32 && element_width != 64) { + return false; + } + + std::vector components; + components = constants[i]->GetVectorComponents(const_mgr); + + const uint32_t kNotFound = std::numeric_limits::max(); + + uint32_t component_with_one = kNotFound; + bool all_others_zero = true; + for (uint32_t j = 0; j < components.size(); ++j) { + const analysis::Constant* element = components[j]; + double value = + (element_width == 32 ? element->GetFloat() : element->GetDouble()); + if (value == 0.0) { + continue; + } else if (value == 1.0) { + if (component_with_one == kNotFound) { + component_with_one = j; + } else { + component_with_one = kNotFound; + break; + } + } else { + all_others_zero = false; + break; + } + } + + if (!all_others_zero || component_with_one == kNotFound) { + continue; + } + + std::vector operands; + operands.push_back( + {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}}); + operands.push_back( + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}}); + + inst->SetOpcode(SpvOpCompositeExtract); + inst->SetInOperands(std::move(operands)); + return true; + } + return false; + }; +} + +// If we are storing an undef, then we can remove the store. +// +// TODO: We can do something similar for OpImageWrite, but checking for volatile +// is complicated. Waiting to see if it is needed. +FoldingRule StoringUndef() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore."); + + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + + // If this is a volatile store, the store cannot be removed. + if (inst->NumInOperands() == 3) { + if (inst->GetSingleWordInOperand(3) & SpvMemoryAccessVolatileMask) { + return false; + } + } + + uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx); + Instruction* object_inst = def_use_mgr->GetDef(object_id); + if (object_inst->opcode() == SpvOpUndef) { + inst->ToNop(); + return true; + } + return false; + }; +} + +FoldingRule VectorShuffleFeedingShuffle() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == SpvOpVectorShuffle && + "Wrong opcode. Should be OpVectorShuffle."); + + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + + Instruction* feeding_shuffle_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + analysis::Vector* op0_type = + type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector(); + uint32_t op0_length = op0_type->element_count(); + + bool feeder_is_op0 = true; + if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) { + feeding_shuffle_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); + feeder_is_op0 = false; + } + + if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) { + return false; + } + + Instruction* feeder2 = + def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0)); + analysis::Vector* feeder_op0_type = + type_mgr->GetType(feeder2->type_id())->AsVector(); + uint32_t feeder_op0_length = feeder_op0_type->element_count(); + + uint32_t new_feeder_id = 0; + std::vector new_operands; + new_operands.resize( + 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands. + const uint32_t undef_literal = 0xffffffff; + for (uint32_t op = 2; op < inst->NumInOperands(); ++op) { + uint32_t component_index = inst->GetSingleWordInOperand(op); + + // Do not interpret the undefined value literal as coming from operand 1. + if (component_index != undef_literal && + feeder_is_op0 == (component_index < op0_length)) { + // This component comes from the feeding_shuffle_inst. Update + // |component_index| to be the index into the operand of the feeder. + + // Adjust component_index to get the index into the operands of the + // feeding_shuffle_inst. + if (component_index >= op0_length) { + component_index -= op0_length; + } + component_index = + feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2); + + // Check if we are using a component from the first or second operand of + // the feeding instruction. + if (component_index < feeder_op0_length) { + if (new_feeder_id == 0) { + // First time through, save the id of the operand the element comes + // from. + new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0); + } else if (new_feeder_id != + feeding_shuffle_inst->GetSingleWordInOperand(0)) { + // We need both elements of the feeding_shuffle_inst, so we cannot + // fold. + return false; + } + } else { + if (new_feeder_id == 0) { + // First time through, save the id of the operand the element comes + // from. + new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1); + } else if (new_feeder_id != + feeding_shuffle_inst->GetSingleWordInOperand(1)) { + // We need both elements of the feeding_shuffle_inst, so we cannot + // fold. + return false; + } + component_index -= feeder_op0_length; + } + + if (!feeder_is_op0) { + component_index += op0_length; + } + } + new_operands.push_back( + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}}); + } + + if (new_feeder_id == 0) { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* type = + type_mgr->GetType(feeding_shuffle_inst->type_id()); + const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); + new_feeder_id = + const_mgr->GetDefiningInstruction(null_const, 0)->result_id(); + } + + if (feeder_is_op0) { + // If the size of the first vector operand changed then the indices + // referring to the second operand need to be adjusted. + Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id); + analysis::Type* new_feeder_type = + type_mgr->GetType(new_feeder_inst->type_id()); + uint32_t new_op0_size = new_feeder_type->AsVector()->element_count(); + int32_t adjustment = op0_length - new_op0_size; + + if (adjustment != 0) { + for (uint32_t i = 2; i < new_operands.size(); i++) { + if (inst->GetSingleWordInOperand(i) >= op0_length) { + new_operands[i].words[0] -= adjustment; + } + } + } + + new_operands[0].words[0] = new_feeder_id; + new_operands[1] = inst->GetInOperand(1); + } else { + new_operands[1].words[0] = new_feeder_id; + new_operands[0] = inst->GetInOperand(0); + } + + inst->SetInOperands(std::move(new_operands)); + return true; + }; +} + } // namespace -spvtools::opt::FoldingRules::FoldingRules() { +FoldingRules::FoldingRules() { // Add all folding rules to the list for the opcodes to which they apply. // Note that the order in which rules are added to the list matters. If a rule // applies to the instruction, the rest of the rules will not be attempted. // Take that into consideration. - rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct()); rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract()); + rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract()); + + rules_[SpvOpDot].push_back(DotProductDoingExtract()); rules_[SpvOpExtInst].push_back(RedundantFMix()); @@ -1853,6 +2210,7 @@ spvtools::opt::FoldingRules::FoldingRules() { rules_[SpvOpFSub].push_back(MergeSubAddArithmetic()); rules_[SpvOpFSub].push_back(MergeSubSubArithmetic()); + rules_[SpvOpIAdd].push_back(RedundantIAdd()); rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic()); @@ -1875,8 +2233,11 @@ spvtools::opt::FoldingRules::FoldingRules() { rules_[SpvOpSelect].push_back(RedundantSelect()); - rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic()); -} + rules_[SpvOpStore].push_back(StoringUndef()); + rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic()); + + rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle()); +} } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/folding_rules.h b/3rdparty/spirv-tools/source/opt/folding_rules.h index 78277e82c..33fdbffe9 100644 --- a/3rdparty/spirv-tools/source/opt/folding_rules.h +++ b/3rdparty/spirv-tools/source/opt/folding_rules.h @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_FOLDING_RULES_H_ -#define LIBSPIRV_UTIL_FOLDING_RULES_H_ +#ifndef SOURCE_OPT_FOLDING_RULES_H_ +#define SOURCE_OPT_FOLDING_RULES_H_ #include +#include #include -#include "constants.h" -#include "def_use_manager.h" -#include "ir_builder.h" -#include "ir_context.h" +#include "source/opt/constants.h" namespace spvtools { namespace opt { @@ -55,14 +53,14 @@ namespace opt { // the later rules will not be attempted. using FoldingRule = std::function& constants)>; class FoldingRules { public: FoldingRules(); - const std::vector& GetRulesForOpcode(SpvOp opcode) { + const std::vector& GetRulesForOpcode(SpvOp opcode) const { auto it = rules_.find(opcode); if (it != rules_.end()) { return it->second; @@ -78,4 +76,4 @@ class FoldingRules { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_UTIL_FOLDING_RULES_H_ +#endif // SOURCE_OPT_FOLDING_RULES_H_ diff --git a/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.cpp b/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.cpp index ef589b03f..10e98fd8b 100644 --- a/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.cpp @@ -12,40 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "freeze_spec_constant_value_pass.h" -#include "ir_context.h" +#include "source/opt/freeze_spec_constant_value_pass.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { -Pass::Status FreezeSpecConstantValuePass::Process(ir::IRContext* irContext) { +Pass::Status FreezeSpecConstantValuePass::Process() { bool modified = false; - irContext->module()->ForEachInst( - [&modified, irContext](ir::Instruction* inst) { - switch (inst->opcode()) { - case SpvOp::SpvOpSpecConstant: - inst->SetOpcode(SpvOp::SpvOpConstant); - modified = true; - break; - case SpvOp::SpvOpSpecConstantTrue: - inst->SetOpcode(SpvOp::SpvOpConstantTrue); - modified = true; - break; - case SpvOp::SpvOpSpecConstantFalse: - inst->SetOpcode(SpvOp::SpvOpConstantFalse); - modified = true; - break; - case SpvOp::SpvOpDecorate: - if (inst->GetSingleWordInOperand(1) == - SpvDecoration::SpvDecorationSpecId) { - irContext->KillInst(inst); - modified = true; - } - break; - default: - break; + auto ctx = context(); + ctx->module()->ForEachInst([&modified, ctx](Instruction* inst) { + switch (inst->opcode()) { + case SpvOp::SpvOpSpecConstant: + inst->SetOpcode(SpvOp::SpvOpConstant); + modified = true; + break; + case SpvOp::SpvOpSpecConstantTrue: + inst->SetOpcode(SpvOp::SpvOpConstantTrue); + modified = true; + break; + case SpvOp::SpvOpSpecConstantFalse: + inst->SetOpcode(SpvOp::SpvOpConstantFalse); + modified = true; + break; + case SpvOp::SpvOpDecorate: + if (inst->GetSingleWordInOperand(1) == + SpvDecoration::SpvDecorationSpecId) { + ctx->KillInst(inst); + modified = true; } - }); + break; + default: + break; + } + }); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } diff --git a/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.h b/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.h index fc7f44e70..0663adf40 100644 --- a/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.h +++ b/3rdparty/spirv-tools/source/opt/freeze_spec_constant_value_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_ -#define LIBSPIRV_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_ +#ifndef SOURCE_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_ +#define SOURCE_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,10 +26,10 @@ namespace opt { class FreezeSpecConstantValuePass : public Pass { public: const char* name() const override { return "freeze-spec-const"; } - Status Process(ir::IRContext*) override; + Status Process() override; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_ +#endif // SOURCE_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/function.cpp b/3rdparty/spirv-tools/source/opt/function.cpp index 5a648eb11..c6894c681 100644 --- a/3rdparty/spirv-tools/source/opt/function.cpp +++ b/3rdparty/spirv-tools/source/opt/function.cpp @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "function.h" +#include "source/opt/function.h" #include #include namespace spvtools { -namespace ir { +namespace opt { Function* Function::Clone(IRContext* ctx) const { Function* clone = @@ -96,7 +96,7 @@ std::ostream& operator<<(std::ostream& str, const Function& func) { std::string Function::PrettyPrint(uint32_t options) const { std::ostringstream str; - ForEachInst([&str, options](const ir::Instruction* inst) { + ForEachInst([&str, options](const Instruction* inst) { str << inst->PrettyPrint(options); if (inst->opcode() != SpvOpFunctionEnd) { str << std::endl; @@ -105,5 +105,5 @@ std::string Function::PrettyPrint(uint32_t options) const { return str.str(); } -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/function.h b/3rdparty/spirv-tools/source/opt/function.h index c4d4c613a..4dc5d25a6 100644 --- a/3rdparty/spirv-tools/source/opt/function.h +++ b/3rdparty/spirv-tools/source/opt/function.h @@ -12,21 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_CONSTRUCTS_H_ -#define LIBSPIRV_OPT_CONSTRUCTS_H_ +#ifndef SOURCE_OPT_FUNCTION_H_ +#define SOURCE_OPT_FUNCTION_H_ #include #include #include +#include #include #include -#include "basic_block.h" -#include "instruction.h" -#include "iterator.h" +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/iterator.h" namespace spvtools { -namespace ir { +namespace opt { class CFG; class IRContext; @@ -53,10 +54,6 @@ class Function { Instruction& DefInst() { return *def_inst_; } const Instruction& DefInst() const { return *def_inst_; } - // Sets the enclosing module for this function. - void SetParent(Module* module) { module_ = module; } - // Gets the enclosing module for this function - Module* GetParent() const { return module_; } // Appends a parameter to this function. inline void AddParameter(std::unique_ptr p); // Appends a basic block to this function. @@ -66,6 +63,13 @@ class Function { template inline void AddBasicBlocks(T begin, T end, iterator ip); + // Move basic block with |id| to the position after |ip|. Both have to be + // contained in this function. + inline void MoveBasicBlockToAfter(uint32_t id, BasicBlock* ip); + + // Delete all basic blocks that contain no instructions. + inline void RemoveEmptyBlocks(); + // Saves the given function end instruction. inline void SetFunctionEnd(std::unique_ptr end_inst); @@ -95,7 +99,7 @@ class Function { // Returns an iterator to the basic block |id|. iterator FindBlock(uint32_t bb_id) { - return std::find_if(begin(), end(), [bb_id](const ir::BasicBlock& it_bb) { + return std::find_if(begin(), end(), [bb_id](const BasicBlock& it_bb) { return bb_id == it_bb.id(); }); } @@ -112,10 +116,7 @@ class Function { void ForEachParam(const std::function& f, bool run_on_debug_line_insts = false) const; - // Returns the context of the current function. - IRContext* context() const { return def_inst_->context(); } - - BasicBlock* InsertBasicBlockAfter(std::unique_ptr&& new_block, + BasicBlock* InsertBasicBlockAfter(std::unique_ptr&& new_block, BasicBlock* position); // Pretty-prints all the basic blocks in this function into a std::string. @@ -125,8 +126,6 @@ class Function { std::string PrettyPrint(uint32_t options = 0u) const; private: - // The enclosing module. - Module* module_; // The OpFunction instruction that begins the definition of this function. std::unique_ptr def_inst_; // All parameters to this function. @@ -141,7 +140,7 @@ class Function { std::ostream& operator<<(std::ostream& str, const Function& func); inline Function::Function(std::unique_ptr def_inst) - : module_(nullptr), def_inst_(std::move(def_inst)), end_inst_() {} + : def_inst_(std::move(def_inst)), end_inst_() {} inline void Function::AddParameter(std::unique_ptr p) { params_.emplace_back(std::move(p)); @@ -162,11 +161,30 @@ inline void Function::AddBasicBlocks(T src_begin, T src_end, iterator ip) { std::make_move_iterator(src_end)); } +inline void Function::MoveBasicBlockToAfter(uint32_t id, BasicBlock* ip) { + auto block_to_move = std::move(*FindBlock(id).Get()); + + assert(block_to_move->GetParent() == ip->GetParent() && + "Both blocks have to be in the same function."); + + InsertBasicBlockAfter(std::move(block_to_move), ip); + blocks_.erase(std::find(std::begin(blocks_), std::end(blocks_), nullptr)); +} + +inline void Function::RemoveEmptyBlocks() { + auto first_empty = + std::remove_if(std::begin(blocks_), std::end(blocks_), + [](const std::unique_ptr& bb) -> bool { + return bb->GetLabelInst()->opcode() == SpvOpNop; + }); + blocks_.erase(first_empty, std::end(blocks_)); +} + inline void Function::SetFunctionEnd(std::unique_ptr end_inst) { end_inst_ = std::move(end_inst); } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_CONSTRUCTS_H_ +#endif // SOURCE_OPT_FUNCTION_H_ diff --git a/3rdparty/spirv-tools/source/opt/if_conversion.cpp b/3rdparty/spirv-tools/source/opt/if_conversion.cpp index 951b7ea45..7a3717f98 100644 --- a/3rdparty/spirv-tools/source/opt/if_conversion.cpp +++ b/3rdparty/spirv-tools/source/opt/if_conversion.cpp @@ -12,23 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "if_conversion.h" +#include "source/opt/if_conversion.h" + +#include +#include + +#include "source/opt/value_number_table.h" namespace spvtools { namespace opt { -Pass::Status IfConversion::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status IfConversion::Process() { + const ValueNumberTable& vn_table = *context()->GetValueNumberTable(); bool modified = false; - std::vector to_kill; + std::vector to_kill; for (auto& func : *get_module()) { - DominatorAnalysis* dominators = - context()->GetDominatorAnalysis(&func, *cfg()); + DominatorAnalysis* dominators = context()->GetDominatorAnalysis(&func); for (auto& block : func) { // Check if it is possible for |block| to have phis that can be // transformed. - ir::BasicBlock* common = nullptr; + BasicBlock* common = nullptr; if (!CheckBlock(&block, dominators, &common)) continue; // Get an insertion point. @@ -39,10 +42,9 @@ Pass::Status IfConversion::Process(ir::IRContext* c) { InstructionBuilder builder( context(), &*iter, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); block.ForEachPhiInst([this, &builder, &modified, &common, &to_kill, - dominators, &block](ir::Instruction* phi) { + dominators, &block, &vn_table](Instruction* phi) { // This phi is not compatible, but subsequent phis might be. if (!CheckType(phi->type_id())) return; @@ -56,13 +58,12 @@ Pass::Status IfConversion::Process(ir::IRContext* c) { // branches. If |then_block| dominates |inc0| or if the true edge // branches straight to this block and |common| is |inc0|, then |inc0| // is on the true branch. Otherwise the |inc1| is on the true branch. - ir::BasicBlock* inc0 = GetIncomingBlock(phi, 0u); - ir::Instruction* branch = common->terminator(); + BasicBlock* inc0 = GetIncomingBlock(phi, 0u); + Instruction* branch = common->terminator(); uint32_t condition = branch->GetSingleWordInOperand(0u); - ir::BasicBlock* then_block = - GetBlock(branch->GetSingleWordInOperand(1u)); - ir::Instruction* true_value = nullptr; - ir::Instruction* false_value = nullptr; + BasicBlock* then_block = GetBlock(branch->GetSingleWordInOperand(1u)); + Instruction* true_value = nullptr; + Instruction* false_value = nullptr; if ((then_block == &block && inc0 == common) || dominators->Dominates(then_block, inc0)) { true_value = GetIncomingValue(phi, 0u); @@ -72,16 +73,46 @@ Pass::Status IfConversion::Process(ir::IRContext* c) { false_value = GetIncomingValue(phi, 0u); } + BasicBlock* true_def_block = context()->get_instr_block(true_value); + BasicBlock* false_def_block = context()->get_instr_block(false_value); + + uint32_t true_vn = vn_table.GetValueNumber(true_value); + uint32_t false_vn = vn_table.GetValueNumber(false_value); + if (true_vn != 0 && true_vn == false_vn) { + Instruction* inst_to_use = nullptr; + + // Try to pick an instruction that is not in a side node. If we can't + // pick either the true for false branch as long as they can be + // legally moved. + if (!true_def_block || + dominators->Dominates(true_def_block, &block)) { + inst_to_use = true_value; + } else if (!false_def_block || + dominators->Dominates(false_def_block, &block)) { + inst_to_use = false_value; + } else if (CanHoistInstruction(true_value, common, dominators)) { + inst_to_use = true_value; + } else if (CanHoistInstruction(false_value, common, dominators)) { + inst_to_use = false_value; + } + + if (inst_to_use != nullptr) { + modified = true; + HoistInstruction(inst_to_use, common, dominators); + context()->KillNamesAndDecorates(phi); + context()->ReplaceAllUsesWith(phi->result_id(), + inst_to_use->result_id()); + } + return; + } + // If either incoming value is defined in a block that does not dominate // this phi, then we cannot eliminate the phi with a select. // TODO(alan-baker): Perform code motion where it makes sense to enable // the transform in this case. - ir::BasicBlock* true_def_block = context()->get_instr_block(true_value); if (true_def_block && !dominators->Dominates(true_def_block, &block)) return; - ir::BasicBlock* false_def_block = - context()->get_instr_block(false_value); if (false_def_block && !dominators->Dominates(false_def_block, &block)) return; @@ -91,9 +122,9 @@ Pass::Status IfConversion::Process(ir::IRContext* c) { condition = SplatCondition(vec_data_ty, condition, &builder); } - ir::Instruction* select = builder.AddSelect(phi->type_id(), condition, - true_value->result_id(), - false_value->result_id()); + Instruction* select = builder.AddSelect(phi->type_id(), condition, + true_value->result_id(), + false_value->result_id()); context()->ReplaceAllUsesWith(phi->result_id(), select->result_id()); to_kill.push_back(phi); modified = true; @@ -110,18 +141,17 @@ Pass::Status IfConversion::Process(ir::IRContext* c) { return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -bool IfConversion::CheckBlock(ir::BasicBlock* block, - DominatorAnalysis* dominators, - ir::BasicBlock** common) { +bool IfConversion::CheckBlock(BasicBlock* block, DominatorAnalysis* dominators, + BasicBlock** common) { const std::vector& preds = cfg()->preds(block->id()); // TODO(alan-baker): Extend to more than two predecessors if (preds.size() != 2) return false; - ir::BasicBlock* inc0 = context()->get_instr_block(preds[0]); + BasicBlock* inc0 = context()->get_instr_block(preds[0]); if (dominators->Dominates(block, inc0)) return false; - ir::BasicBlock* inc1 = context()->get_instr_block(preds[1]); + BasicBlock* inc1 = context()->get_instr_block(preds[1]); if (dominators->Dominates(block, inc1)) return false; // All phis will have the same common dominator, so cache the result @@ -129,15 +159,15 @@ bool IfConversion::CheckBlock(ir::BasicBlock* block, // any phi in this basic block. *common = dominators->CommonDominator(inc0, inc1); if (!*common || cfg()->IsPseudoEntryBlock(*common)) return false; - ir::Instruction* branch = (*common)->terminator(); + Instruction* branch = (*common)->terminator(); if (branch->opcode() != SpvOpBranchConditional) return false; return true; } -bool IfConversion::CheckPhiUsers(ir::Instruction* phi, ir::BasicBlock* block) { +bool IfConversion::CheckPhiUsers(Instruction* phi, BasicBlock* block) { return get_def_use_mgr()->WhileEachUser(phi, [block, - this](ir::Instruction* user) { + this](Instruction* user) { if (user->opcode() == SpvOpPhi && context()->get_instr_block(user) == block) return false; return true; @@ -160,7 +190,7 @@ uint32_t IfConversion::SplatCondition(analysis::Vector* vec_data_ty, } bool IfConversion::CheckType(uint32_t id) { - ir::Instruction* type = get_def_use_mgr()->GetDef(id); + Instruction* type = get_def_use_mgr()->GetDef(id); SpvOp op = type->opcode(); if (spvOpcodeIsScalarType(op) || op == SpvOpTypePointer || op == SpvOpTypeVector) @@ -168,21 +198,81 @@ bool IfConversion::CheckType(uint32_t id) { return false; } -ir::BasicBlock* IfConversion::GetBlock(uint32_t id) { +BasicBlock* IfConversion::GetBlock(uint32_t id) { return context()->get_instr_block(get_def_use_mgr()->GetDef(id)); } -ir::BasicBlock* IfConversion::GetIncomingBlock(ir::Instruction* phi, - uint32_t predecessor) { +BasicBlock* IfConversion::GetIncomingBlock(Instruction* phi, + uint32_t predecessor) { uint32_t in_index = 2 * predecessor + 1; return GetBlock(phi->GetSingleWordInOperand(in_index)); } -ir::Instruction* IfConversion::GetIncomingValue(ir::Instruction* phi, - uint32_t predecessor) { +Instruction* IfConversion::GetIncomingValue(Instruction* phi, + uint32_t predecessor) { uint32_t in_index = 2 * predecessor; return get_def_use_mgr()->GetDef(phi->GetSingleWordInOperand(in_index)); } +void IfConversion::HoistInstruction(Instruction* inst, BasicBlock* target_block, + DominatorAnalysis* dominators) { + BasicBlock* inst_block = context()->get_instr_block(inst); + if (!inst_block) { + // This is in the header, and dominates everything. + return; + } + + if (dominators->Dominates(inst_block, target_block)) { + // Already in position. No work to do. + return; + } + + assert(inst->IsOpcodeCodeMotionSafe() && + "Trying to move an instruction that is not safe to move."); + + // First hoist all instructions it depends on. + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + inst->ForEachInId( + [this, target_block, def_use_mgr, dominators](uint32_t* id) { + Instruction* operand_inst = def_use_mgr->GetDef(*id); + HoistInstruction(operand_inst, target_block, dominators); + }); + + Instruction* insertion_pos = target_block->terminator(); + if ((insertion_pos)->PreviousNode()->opcode() == SpvOpSelectionMerge) { + insertion_pos = insertion_pos->PreviousNode(); + } + inst->RemoveFromList(); + insertion_pos->InsertBefore(std::unique_ptr(inst)); + context()->set_instr_block(inst, target_block); +} + +bool IfConversion::CanHoistInstruction(Instruction* inst, + BasicBlock* target_block, + DominatorAnalysis* dominators) { + BasicBlock* inst_block = context()->get_instr_block(inst); + if (!inst_block) { + // This is in the header, and dominates everything. + return true; + } + + if (dominators->Dominates(inst_block, target_block)) { + // Already in position. No work to do. + return true; + } + + if (!inst->IsOpcodeCodeMotionSafe()) { + return false; + } + + // Check all instruction |inst| depends on. + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + return inst->WhileEachInId( + [this, target_block, def_use_mgr, dominators](uint32_t* id) { + Instruction* operand_inst = def_use_mgr->GetDef(*id); + return CanHoistInstruction(operand_inst, target_block, dominators); + }); +} + } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/if_conversion.h b/3rdparty/spirv-tools/source/opt/if_conversion.h index eb97406d4..609bdf392 100644 --- a/3rdparty/spirv-tools/source/opt/if_conversion.h +++ b/3rdparty/spirv-tools/source/opt/if_conversion.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_IF_CONVERSION_H_ -#define LIBSPIRV_OPT_IF_CONVERSION_H_ +#ifndef SOURCE_OPT_IF_CONVERSION_H_ +#define SOURCE_OPT_IF_CONVERSION_H_ -#include "basic_block.h" -#include "ir_builder.h" -#include "pass.h" -#include "types.h" +#include "source/opt/basic_block.h" +#include "source/opt/ir_builder.h" +#include "source/opt/pass.h" +#include "source/opt/types.h" namespace spvtools { namespace opt { @@ -27,13 +27,12 @@ namespace opt { class IfConversion : public Pass { public: const char* name() const override { return "if-conversion"; } - Status Process(ir::IRContext* context) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisDominatorAnalysis | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisCFG | ir::IRContext::kAnalysisNameMap; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisCFG | + IRContext::kAnalysisNameMap; } private: @@ -42,14 +41,14 @@ class IfConversion : public Pass { bool CheckType(uint32_t id); // Returns the basic block containing |id|. - ir::BasicBlock* GetBlock(uint32_t id); + BasicBlock* GetBlock(uint32_t id); // Returns the basic block for the |predecessor|'th index predecessor of // |phi|. - ir::BasicBlock* GetIncomingBlock(ir::Instruction* phi, uint32_t predecessor); + BasicBlock* GetIncomingBlock(Instruction* phi, uint32_t predecessor); // Returns the instruction defining the |predecessor|'th index of |phi|. - ir::Instruction* GetIncomingValue(ir::Instruction* phi, uint32_t predecessor); + Instruction* GetIncomingValue(Instruction* phi, uint32_t predecessor); // Returns the id of a OpCompositeConstruct boolean vector. The composite has // the same number of elements as |vec_data_ty| and each member is |cond|. @@ -60,17 +59,30 @@ class IfConversion : public Pass { InstructionBuilder* builder); // Returns true if none of |phi|'s users are in |block|. - bool CheckPhiUsers(ir::Instruction* phi, ir::BasicBlock* block); + bool CheckPhiUsers(Instruction* phi, BasicBlock* block); // Returns |false| if |block| is not appropriate to transform. Only // transforms blocks with two predecessors. Neither incoming block can be // dominated by |block|. Both predecessors must share a common dominator that // is terminated by a conditional branch. - bool CheckBlock(ir::BasicBlock* block, DominatorAnalysis* dominators, - ir::BasicBlock** common); + bool CheckBlock(BasicBlock* block, DominatorAnalysis* dominators, + BasicBlock** common); + + // Moves |inst| to |target_block| if it does not already dominate the block. + // Any instructions that |inst| depends on are move if necessary. It is + // assumed that |inst| can be hoisted to |target_block| as defined by + // |CanHoistInstruction|. |dominators| is the dominator analysis for the + // function that contains |target_block|. + void HoistInstruction(Instruction* inst, BasicBlock* target_block, + DominatorAnalysis* dominators); + + // Returns true if it is legal to move |inst| and the instructions it depends + // on to |target_block| if they do not already dominate |target_block|. + bool CanHoistInstruction(Instruction* inst, BasicBlock* target_block, + DominatorAnalysis* dominators); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_IF_CONVERSION_H_ +#endif // SOURCE_OPT_IF_CONVERSION_H_ diff --git a/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.cpp b/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.cpp index a5bc9f358..5714cd867 100644 --- a/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.cpp @@ -14,20 +14,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "inline_exhaustive_pass.h" +#include "source/opt/inline_exhaustive_pass.h" + +#include namespace spvtools { namespace opt { -bool InlineExhaustivePass::InlineExhaustive(ir::Function* func) { +bool InlineExhaustivePass::InlineExhaustive(Function* func) { bool modified = false; // Using block iterators here because of block erasures and insertions. for (auto bi = func->begin(); bi != func->end(); ++bi) { for (auto ii = bi->begin(); ii != bi->end();) { if (IsInlinableFunctionCall(&*ii)) { // Inline call. - std::vector> newBlocks; - std::vector> newVars; + std::vector> newBlocks; + std::vector> newVars; GenInlineCode(&newBlocks, &newVars, ii, bi); // If call block is replaced with more than one block, point // succeeding phis at new last block. @@ -59,21 +61,17 @@ bool InlineExhaustivePass::InlineExhaustive(ir::Function* func) { return modified; } -void InlineExhaustivePass::Initialize(ir::IRContext* c) { InitializeInline(c); } - Pass::Status InlineExhaustivePass::ProcessImpl() { // Attempt exhaustive inlining on each entry point function in module - ProcessFunction pfn = [this](ir::Function* fp) { - return InlineExhaustive(fp); - }; + ProcessFunction pfn = [this](Function* fp) { return InlineExhaustive(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -InlineExhaustivePass::InlineExhaustivePass() {} +InlineExhaustivePass::InlineExhaustivePass() = default; -Pass::Status InlineExhaustivePass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status InlineExhaustivePass::Process() { + InitializeInline(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.h b/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.h index 08b4387b5..103e091e0 100644 --- a/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.h +++ b/3rdparty/spirv-tools/source/opt/inline_exhaustive_pass.h @@ -14,8 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_INLINE_EXHAUSTIVE_PASS_H_ -#define LIBSPIRV_OPT_INLINE_EXHAUSTIVE_PASS_H_ +#ifndef SOURCE_OPT_INLINE_EXHAUSTIVE_PASS_H_ +#define SOURCE_OPT_INLINE_EXHAUSTIVE_PASS_H_ #include #include @@ -23,9 +23,9 @@ #include #include -#include "def_use_manager.h" -#include "inline_pass.h" -#include "module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/inline_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -34,20 +34,20 @@ namespace opt { class InlineExhaustivePass : public InlinePass { public: InlineExhaustivePass(); - Status Process(ir::IRContext* c) override; + Status Process() override; const char* name() const override { return "inline-entry-points-exhaustive"; } private: // Exhaustively inline all function calls in func as well as in // all code that is inlined into func. Return true if func is modified. - bool InlineExhaustive(ir::Function* func); + bool InlineExhaustive(Function* func); - void Initialize(ir::IRContext* c); + void Initialize(); Pass::Status ProcessImpl(); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_INLINE_EXHAUSTIVE_PASS_H_ +#endif // SOURCE_OPT_INLINE_EXHAUSTIVE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/inline_opaque_pass.cpp b/3rdparty/spirv-tools/source/opt/inline_opaque_pass.cpp index 13880502e..c2c3719fe 100644 --- a/3rdparty/spirv-tools/source/opt/inline_opaque_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/inline_opaque_pass.cpp @@ -14,11 +14,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "inline_opaque_pass.h" +#include "source/opt/inline_opaque_pass.h" + +#include namespace spvtools { namespace opt { - namespace { const uint32_t kTypePointerTypeIdInIdx = 1; @@ -26,7 +27,7 @@ const uint32_t kTypePointerTypeIdInIdx = 1; } // anonymous namespace bool InlineOpaquePass::IsOpaqueType(uint32_t typeId) { - const ir::Instruction* typeInst = get_def_use_mgr()->GetDef(typeId); + const Instruction* typeInst = get_def_use_mgr()->GetDef(typeId); switch (typeInst->opcode()) { case SpvOpTypeSampler: case SpvOpTypeImage: @@ -47,14 +48,14 @@ bool InlineOpaquePass::IsOpaqueType(uint32_t typeId) { }); } -bool InlineOpaquePass::HasOpaqueArgsOrReturn(const ir::Instruction* callInst) { +bool InlineOpaquePass::HasOpaqueArgsOrReturn(const Instruction* callInst) { // Check return type if (IsOpaqueType(callInst->type_id())) return true; // Check args int icnt = 0; return !callInst->WhileEachInId([&icnt, this](const uint32_t* iid) { if (icnt > 0) { - const ir::Instruction* argInst = get_def_use_mgr()->GetDef(*iid); + const Instruction* argInst = get_def_use_mgr()->GetDef(*iid); if (IsOpaqueType(argInst->type_id())) return false; } ++icnt; @@ -62,15 +63,15 @@ bool InlineOpaquePass::HasOpaqueArgsOrReturn(const ir::Instruction* callInst) { }); } -bool InlineOpaquePass::InlineOpaque(ir::Function* func) { +bool InlineOpaquePass::InlineOpaque(Function* func) { bool modified = false; // Using block iterators here because of block erasures and insertions. for (auto bi = func->begin(); bi != func->end(); ++bi) { for (auto ii = bi->begin(); ii != bi->end();) { if (IsInlinableFunctionCall(&*ii) && HasOpaqueArgsOrReturn(&*ii)) { // Inline call. - std::vector> newBlocks; - std::vector> newVars; + std::vector> newBlocks; + std::vector> newVars; GenInlineCode(&newBlocks, &newVars, ii, bi); // If call block is replaced with more than one block, point // succeeding phis at new last block. @@ -92,19 +93,19 @@ bool InlineOpaquePass::InlineOpaque(ir::Function* func) { return modified; } -void InlineOpaquePass::Initialize(ir::IRContext* c) { InitializeInline(c); } +void InlineOpaquePass::Initialize() { InitializeInline(); } Pass::Status InlineOpaquePass::ProcessImpl() { // Do opaque inlining on each function in entry point call tree - ProcessFunction pfn = [this](ir::Function* fp) { return InlineOpaque(fp); }; + ProcessFunction pfn = [this](Function* fp) { return InlineOpaque(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -InlineOpaquePass::InlineOpaquePass() {} +InlineOpaquePass::InlineOpaquePass() = default; -Pass::Status InlineOpaquePass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status InlineOpaquePass::Process() { + Initialize(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/inline_opaque_pass.h b/3rdparty/spirv-tools/source/opt/inline_opaque_pass.h index 0a6cf0ecb..aad43fd6a 100644 --- a/3rdparty/spirv-tools/source/opt/inline_opaque_pass.h +++ b/3rdparty/spirv-tools/source/opt/inline_opaque_pass.h @@ -14,8 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_INLINE_OPAQUE_PASS_H_ -#define LIBSPIRV_OPT_INLINE_OPAQUE_PASS_H_ +#ifndef SOURCE_OPT_INLINE_OPAQUE_PASS_H_ +#define SOURCE_OPT_INLINE_OPAQUE_PASS_H_ #include #include @@ -23,9 +23,9 @@ #include #include -#include "def_use_manager.h" -#include "inline_pass.h" -#include "module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/inline_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -34,7 +34,7 @@ namespace opt { class InlineOpaquePass : public InlinePass { public: InlineOpaquePass(); - Status Process(ir::IRContext* c) override; + Status Process() override; const char* name() const override { return "inline-entry-points-opaque"; } @@ -43,18 +43,18 @@ class InlineOpaquePass : public InlinePass { bool IsOpaqueType(uint32_t typeId); // Return true if function call |callInst| has opaque argument or return type - bool HasOpaqueArgsOrReturn(const ir::Instruction* callInst); + bool HasOpaqueArgsOrReturn(const Instruction* callInst); // Inline all function calls in |func| that have opaque params or return // type. Inline similarly all code that is inlined into func. Return true // if func is modified. - bool InlineOpaque(ir::Function* func); + bool InlineOpaque(Function* func); - void Initialize(ir::IRContext* c); + void Initialize(); Pass::Status ProcessImpl(); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_INLINE_OPAQUE_PASS_H_ +#endif // SOURCE_OPT_INLINE_OPAQUE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/inline_pass.cpp b/3rdparty/spirv-tools/source/opt/inline_pass.cpp index 61c734329..5a88ef5d3 100644 --- a/3rdparty/spirv-tools/source/opt/inline_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/inline_pass.cpp @@ -14,9 +14,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "inline_pass.h" +#include "source/opt/inline_pass.h" -#include "cfa.h" +#include +#include + +#include "source/cfa.h" +#include "source/util/make_unique.h" // Indices of operands in SPIR-V instructions @@ -32,11 +36,11 @@ namespace opt { uint32_t InlinePass::AddPointerToType(uint32_t type_id, SpvStorageClass storage_class) { uint32_t resultId = TakeNextId(); - std::unique_ptr type_inst(new ir::Instruction( - context(), SpvOpTypePointer, 0, resultId, - {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, - {uint32_t(storage_class)}}, - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); + std::unique_ptr type_inst( + new Instruction(context(), SpvOpTypePointer, 0, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {uint32_t(storage_class)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); context()->AddType(std::move(type_inst)); analysis::Type* pointeeTy; std::unique_ptr pointerTy; @@ -48,27 +52,27 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id, } void InlinePass::AddBranch(uint32_t label_id, - std::unique_ptr* block_ptr) { - std::unique_ptr newBranch(new ir::Instruction( - context(), SpvOpBranch, 0, 0, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); + std::unique_ptr* block_ptr) { + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); (*block_ptr)->AddInstruction(std::move(newBranch)); } void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id, uint32_t false_id, - std::unique_ptr* block_ptr) { - std::unique_ptr newBranch(new ir::Instruction( - context(), SpvOpBranchConditional, 0, 0, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}})); + std::unique_ptr* block_ptr) { + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranchConditional, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}})); (*block_ptr)->AddInstruction(std::move(newBranch)); } void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, - std::unique_ptr* block_ptr) { - std::unique_ptr newLoopMerge(new ir::Instruction( + std::unique_ptr* block_ptr) { + std::unique_ptr newLoopMerge(new Instruction( context(), SpvOpLoopMerge, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}}, @@ -77,25 +81,25 @@ void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, } void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, - std::unique_ptr* block_ptr) { - std::unique_ptr newStore(new ir::Instruction( - context(), SpvOpStore, 0, 0, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); + std::unique_ptr* block_ptr) { + std::unique_ptr newStore( + new Instruction(context(), SpvOpStore, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); (*block_ptr)->AddInstruction(std::move(newStore)); } void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id, - std::unique_ptr* block_ptr) { - std::unique_ptr newLoad(new ir::Instruction( - context(), SpvOpLoad, type_id, resultId, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); + std::unique_ptr* block_ptr) { + std::unique_ptr newLoad( + new Instruction(context(), SpvOpLoad, type_id, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); (*block_ptr)->AddInstruction(std::move(newLoad)); } -std::unique_ptr InlinePass::NewLabel(uint32_t label_id) { - std::unique_ptr newLabel( - new ir::Instruction(context(), SpvOpLabel, 0, label_id, {})); +std::unique_ptr InlinePass::NewLabel(uint32_t label_id) { + std::unique_ptr newLabel( + new Instruction(context(), SpvOpLabel, 0, label_id, {})); return newLabel; } @@ -114,27 +118,25 @@ uint32_t InlinePass::GetFalseId() { } void InlinePass::MapParams( - ir::Function* calleeFn, ir::BasicBlock::iterator call_inst_itr, + Function* calleeFn, BasicBlock::iterator call_inst_itr, std::unordered_map* callee2caller) { int param_idx = 0; - calleeFn->ForEachParam( - [&call_inst_itr, ¶m_idx, &callee2caller](const ir::Instruction* cpi) { - const uint32_t pid = cpi->result_id(); - (*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand( - kSpvFunctionCallArgumentId + param_idx); - ++param_idx; - }); + calleeFn->ForEachParam([&call_inst_itr, ¶m_idx, + &callee2caller](const Instruction* cpi) { + const uint32_t pid = cpi->result_id(); + (*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand( + kSpvFunctionCallArgumentId + param_idx); + ++param_idx; + }); } void InlinePass::CloneAndMapLocals( - ir::Function* calleeFn, - std::vector>* new_vars, + Function* calleeFn, std::vector>* new_vars, std::unordered_map* callee2caller) { auto callee_block_itr = calleeFn->begin(); auto callee_var_itr = callee_block_itr->begin(); while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) { - std::unique_ptr var_inst( - callee_var_itr->Clone(callee_var_itr->context())); + std::unique_ptr var_inst(callee_var_itr->Clone(context())); uint32_t newId = TakeNextId(); get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId); var_inst->SetResultId(newId); @@ -145,8 +147,7 @@ void InlinePass::CloneAndMapLocals( } uint32_t InlinePass::CreateReturnVar( - ir::Function* calleeFn, - std::vector>* new_vars) { + Function* calleeFn, std::vector>* new_vars) { uint32_t returnVarId = 0; const uint32_t calleeTypeId = calleeFn->type_id(); analysis::Type* calleeType = context()->get_type_mgr()->GetType(calleeTypeId); @@ -158,25 +159,25 @@ uint32_t InlinePass::CreateReturnVar( returnVarTypeId = AddPointerToType(calleeTypeId, SpvStorageClassFunction); // Add return var to new function scope variables. returnVarId = TakeNextId(); - std::unique_ptr var_inst(new ir::Instruction( - context(), SpvOpVariable, returnVarTypeId, returnVarId, - {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, - {SpvStorageClassFunction}}})); + std::unique_ptr var_inst( + new Instruction(context(), SpvOpVariable, returnVarTypeId, returnVarId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {SpvStorageClassFunction}}})); new_vars->push_back(std::move(var_inst)); } get_decoration_mgr()->CloneDecorations(calleeFn->result_id(), returnVarId); return returnVarId; } -bool InlinePass::IsSameBlockOp(const ir::Instruction* inst) const { +bool InlinePass::IsSameBlockOp(const Instruction* inst) const { return inst->opcode() == SpvOpSampledImage || inst->opcode() == SpvOpImage; } void InlinePass::CloneSameBlockOps( - std::unique_ptr* inst, + std::unique_ptr* inst, std::unordered_map* postCallSB, - std::unordered_map* preCallSB, - std::unique_ptr* block_ptr) { + std::unordered_map* preCallSB, + std::unique_ptr* block_ptr) { (*inst)->ForEachInId( [&postCallSB, &preCallSB, &block_ptr, this](uint32_t* iid) { const auto mapItr = (*postCallSB).find(*iid); @@ -184,9 +185,8 @@ void InlinePass::CloneSameBlockOps( const auto mapItr2 = (*preCallSB).find(*iid); if (mapItr2 != (*preCallSB).end()) { // Clone pre-call same-block ops, map result id. - const ir::Instruction* inInst = mapItr2->second; - std::unique_ptr sb_inst( - inInst->Clone(inInst->context())); + const Instruction* inInst = mapItr2->second; + std::unique_ptr sb_inst(inInst->Clone(context())); CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr); const uint32_t rid = sb_inst->result_id(); const uint32_t nid = this->TakeNextId(); @@ -204,24 +204,24 @@ void InlinePass::CloneSameBlockOps( } void InlinePass::GenInlineCode( - std::vector>* new_blocks, - std::vector>* new_vars, - ir::BasicBlock::iterator call_inst_itr, - ir::UptrVectorIterator call_block_itr) { + std::vector>* new_blocks, + std::vector>* new_vars, + BasicBlock::iterator call_inst_itr, + UptrVectorIterator call_block_itr) { // Map from all ids in the callee to their equivalent id in the caller // as callee instructions are copied into caller. std::unordered_map callee2caller; // Pre-call same-block insts - std::unordered_map preCallSB; + std::unordered_map preCallSB; // Post-call same-block op ids std::unordered_map postCallSB; // Invalidate the def-use chains. They are not kept up to date while // inlining. However, certain calls try to keep them up-to-date if they are // valid. These operations can fail. - context()->InvalidateAnalyses(ir::IRContext::kAnalysisDefUse); + context()->InvalidateAnalyses(IRContext::kAnalysisDefUse); - ir::Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand( + Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand( kSpvFunctionCallFunctionId)]; // Check for multiple returns in the callee. @@ -240,7 +240,7 @@ void InlinePass::GenInlineCode( // Create set of callee result ids. Used to detect forward references std::unordered_set callee_result_ids; - calleeFn->ForEachInst([&callee_result_ids](const ir::Instruction* cpi) { + calleeFn->ForEachInst([&callee_result_ids](const Instruction* cpi) { const uint32_t rid = cpi->result_id(); if (rid != 0) callee_result_ids.insert(rid); }); @@ -275,20 +275,41 @@ void InlinePass::GenInlineCode( // written to it. It is created when we encounter the OpLabel // of the first callee block. It is appended to new_blocks only when // it is complete. - std::unique_ptr new_blk_ptr; + std::unique_ptr new_blk_ptr; calleeFn->ForEachInst([&new_blocks, &callee2caller, &call_block_itr, &call_inst_itr, &new_blk_ptr, &prevInstWasReturn, &returnLabelId, &returnVarId, caller_is_loop_header, callee_begins_with_structured_header, &calleeTypeId, &multiBlocks, &postCallSB, &preCallSB, multiReturn, &singleTripLoopHeaderId, &singleTripLoopContinueId, - &callee_result_ids, this](const ir::Instruction* cpi) { + &callee_result_ids, this](const Instruction* cpi) { switch (cpi->opcode()) { case SpvOpFunction: case SpvOpFunctionParameter: - case SpvOpVariable: // Already processed break; + case SpvOpVariable: + if (cpi->NumInOperands() == 2) { + assert(callee2caller.count(cpi->result_id()) && + "Expected the variable to have already been mapped."); + uint32_t new_var_id = callee2caller.at(cpi->result_id()); + + // The initializer must be a constant or global value. No mapped + // should be used. + uint32_t val_id = cpi->GetSingleWordInOperand(1); + AddStore(new_var_id, val_id, &new_blk_ptr); + } + break; + case SpvOpUnreachable: + case SpvOpKill: { + // Generate a return label so that we split the block with the function + // call. Copy the terminator into the new block. + if (returnLabelId == 0) returnLabelId = this->TakeNextId(); + std::unique_ptr terminator( + new Instruction(context(), cpi->opcode(), 0, 0, {})); + new_blk_ptr->AddInstruction(std::move(terminator)); + break; + } case SpvOpLabel: { // If previous instruction was early return, insert branch // instruction to return block. @@ -316,14 +337,14 @@ void InlinePass::GenInlineCode( firstBlock = true; } // Create first/next block. - new_blk_ptr.reset(new ir::BasicBlock(NewLabel(labelId))); + new_blk_ptr = MakeUnique(NewLabel(labelId)); if (firstBlock) { // Copy contents of original caller block up to call instruction. for (auto cii = call_block_itr->begin(); cii != call_inst_itr; cii = call_block_itr->begin()) { - ir::Instruction* inst = &*cii; + Instruction* inst = &*cii; inst->RemoveFromList(); - std::unique_ptr cp_inst(inst); + std::unique_ptr cp_inst(inst); // Remember same-block ops for possible regeneration. if (IsSameBlockOp(&*cp_inst)) { auto* sb_inst_ptr = cp_inst.get(); @@ -342,7 +363,7 @@ void InlinePass::GenInlineCode( AddBranch(guard_block_id, &new_blk_ptr); new_blocks->push_back(std::move(new_blk_ptr)); // Start the next block. - new_blk_ptr.reset(new ir::BasicBlock(NewLabel(guard_block_id))); + new_blk_ptr = MakeUnique(NewLabel(guard_block_id)); // Reset the mapping of the callee's entry block to point to // the guard block. Do this so we can fix up phis later on to // satisfy dominance. @@ -363,15 +384,15 @@ void InlinePass::GenInlineCode( singleTripLoopHeaderId = this->TakeNextId(); AddBranch(singleTripLoopHeaderId, &new_blk_ptr); new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr.reset( - new ir::BasicBlock(NewLabel(singleTripLoopHeaderId))); + new_blk_ptr = + MakeUnique(NewLabel(singleTripLoopHeaderId)); returnLabelId = this->TakeNextId(); singleTripLoopContinueId = this->TakeNextId(); AddLoopMerge(returnLabelId, singleTripLoopContinueId, &new_blk_ptr); uint32_t postHeaderId = this->TakeNextId(); AddBranch(postHeaderId, &new_blk_ptr); new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr.reset(new ir::BasicBlock(NewLabel(postHeaderId))); + new_blk_ptr = MakeUnique(NewLabel(postHeaderId)); multiBlocks = true; // Reset the mapping of the callee's entry block to point to // the post-header block. Do this so we can fix up phis later @@ -413,14 +434,14 @@ void InlinePass::GenInlineCode( // to accommodate multiple returns, insert the continue // target block now, with a false branch back to the loop header. new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr.reset( - new ir::BasicBlock(NewLabel(singleTripLoopContinueId))); + new_blk_ptr = + MakeUnique(NewLabel(singleTripLoopContinueId)); AddBranchCond(GetFalseId(), singleTripLoopHeaderId, returnLabelId, &new_blk_ptr); } // Generate the return block. new_blocks->push_back(std::move(new_blk_ptr)); - new_blk_ptr.reset(new ir::BasicBlock(NewLabel(returnLabelId))); + new_blk_ptr = MakeUnique(NewLabel(returnLabelId)); multiBlocks = true; } // Load return value into result id of call, if it exists. @@ -430,10 +451,10 @@ void InlinePass::GenInlineCode( AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr); } // Copy remaining instructions from caller block. - for (ir::Instruction* inst = call_inst_itr->NextNode(); inst; + for (Instruction* inst = call_inst_itr->NextNode(); inst; inst = call_inst_itr->NextNode()) { inst->RemoveFromList(); - std::unique_ptr cp_inst(inst); + std::unique_ptr cp_inst(inst); // If multiple blocks generated, regenerate any same-block // instruction that has not been seen in this last block. if (multiBlocks) { @@ -451,7 +472,7 @@ void InlinePass::GenInlineCode( } break; default: { // Copy callee instruction and remap all input Ids. - std::unique_ptr cp_inst(cpi->Clone(context())); + std::unique_ptr cp_inst(cpi->Clone(context())); cp_inst->ForEachInId([&callee2caller, &callee_result_ids, this](uint32_t* iid) { const auto mapItr = callee2caller.find(*iid); @@ -496,7 +517,7 @@ void InlinePass::GenInlineCode( auto loop_merge_itr = last->tail(); --loop_merge_itr; assert(loop_merge_itr->opcode() == SpvOpLoopMerge); - std::unique_ptr cp_inst(loop_merge_itr->Clone(context())); + std::unique_ptr cp_inst(loop_merge_itr->Clone(context())); if (caller_is_single_block_loop) { // Also, update its continue target to point to the last block. cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()}); @@ -514,7 +535,7 @@ void InlinePass::GenInlineCode( } } -bool InlinePass::IsInlinableFunctionCall(const ir::Instruction* inst) { +bool InlinePass::IsInlinableFunctionCall(const Instruction* inst) { if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false; const uint32_t calleeFnId = inst->GetSingleWordOperand(kSpvFunctionCallFunctionId); @@ -523,16 +544,16 @@ bool InlinePass::IsInlinableFunctionCall(const ir::Instruction* inst) { } void InlinePass::UpdateSucceedingPhis( - std::vector>& new_blocks) { + std::vector>& new_blocks) { const auto firstBlk = new_blocks.begin(); const auto lastBlk = new_blocks.end() - 1; const uint32_t firstId = (*firstBlk)->id(); const uint32_t lastId = (*lastBlk)->id(); - const ir::BasicBlock& const_last_block = *lastBlk->get(); + const BasicBlock& const_last_block = *lastBlk->get(); const_last_block.ForEachSuccessorLabel( [&firstId, &lastId, this](const uint32_t succ) { - ir::BasicBlock* sbp = this->id2block_[succ]; - sbp->ForEachPhiInst([&firstId, &lastId](ir::Instruction* phi) { + BasicBlock* sbp = this->id2block_[succ]; + sbp->ForEachPhiInst([&firstId, &lastId](Instruction* phi) { phi->ForEachInId([&firstId, &lastId](uint32_t* id) { if (*id == firstId) *id = lastId; }); @@ -540,7 +561,7 @@ void InlinePass::UpdateSucceedingPhis( }); } -bool InlinePass::HasMultipleReturns(ir::Function* func) { +bool InlinePass::HasMultipleReturns(Function* func) { bool seenReturn = false; bool multipleReturns = false; for (auto& blk : *func) { @@ -558,7 +579,7 @@ bool InlinePass::HasMultipleReturns(ir::Function* func) { return multipleReturns; } -void InlinePass::ComputeStructuredSuccessors(ir::Function* func) { +void InlinePass::ComputeStructuredSuccessors(Function* func) { // If header, make merge block first successor. for (auto& blk : *func) { uint32_t mbid = blk.MergeBlockIdIfAny(); @@ -575,12 +596,12 @@ void InlinePass::ComputeStructuredSuccessors(ir::Function* func) { } InlinePass::GetBlocksFunction InlinePass::StructuredSuccessorsFunction() { - return [this](const ir::BasicBlock* block) { + return [this](const BasicBlock* block) { return &(block2structured_succs_[block]); }; } -bool InlinePass::HasNoReturnInLoop(ir::Function* func) { +bool InlinePass::HasNoReturnInLoop(Function* func) { // If control not structured, do not do loop/return analysis // TODO: Analyze returns in non-structured control flow if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) @@ -591,8 +612,8 @@ bool InlinePass::HasNoReturnInLoop(ir::Function* func) { ComputeStructuredSuccessors(func); auto ignore_block = [](cbb_ptr) {}; auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; - std::list structuredOrder; - spvtools::CFA::DepthFirstTraversal( + std::list structuredOrder; + CFA::DepthFirstTraversal( &*func->begin(), StructuredSuccessorsFunction(), ignore_block, [&](cbb_ptr b) { structuredOrder.push_front(b); }, ignore_edge); // Search for returns in loops. Only need to track outermost loop @@ -622,7 +643,7 @@ bool InlinePass::HasNoReturnInLoop(ir::Function* func) { return !return_in_loop; } -void InlinePass::AnalyzeReturns(ir::Function* func) { +void InlinePass::AnalyzeReturns(Function* func) { // Look for multiple returns if (!HasMultipleReturns(func)) { no_return_in_loop_.insert(func->result_id()); @@ -633,7 +654,7 @@ void InlinePass::AnalyzeReturns(ir::Function* func) { if (HasNoReturnInLoop(func)) no_return_in_loop_.insert(func->result_id()); } -bool InlinePass::IsInlinableFunction(ir::Function* func) { +bool InlinePass::IsInlinableFunction(Function* func) { // We can only inline a function if it has blocks. if (func->cbegin() == func->cend()) return false; // Do not inline functions with returns in loops. Currently early return @@ -646,9 +667,7 @@ bool InlinePass::IsInlinableFunction(ir::Function* func) { no_return_in_loop_.cend(); } -void InlinePass::InitializeInline(ir::IRContext* c) { - InitializeProcessing(c); - +void InlinePass::InitializeInline() { false_id_ = 0; // clear collections diff --git a/3rdparty/spirv-tools/source/opt/inline_pass.h b/3rdparty/spirv-tools/source/opt/inline_pass.h index dd2d0c7d0..55369c98c 100644 --- a/3rdparty/spirv-tools/source/opt/inline_pass.h +++ b/3rdparty/spirv-tools/source/opt/inline_pass.h @@ -14,93 +14,91 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_INLINE_PASS_H_ -#define LIBSPIRV_OPT_INLINE_PASS_H_ +#ifndef SOURCE_OPT_INLINE_PASS_H_ +#define SOURCE_OPT_INLINE_PASS_H_ #include #include #include +#include #include #include -#include "decoration_manager.h" -#include "module.h" -#include "pass.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { // See optimizer.hpp for documentation. class InlinePass : public Pass { - using cbb_ptr = const ir::BasicBlock*; + using cbb_ptr = const BasicBlock*; public: using GetBlocksFunction = - std::function*(const ir::BasicBlock*)>; + std::function*(const BasicBlock*)>; - InlinePass(); virtual ~InlinePass() = default; protected: + InlinePass(); // Add pointer to type to module and return resultId. uint32_t AddPointerToType(uint32_t type_id, SpvStorageClass storage_class); // Add unconditional branch to labelId to end of block block_ptr. - void AddBranch(uint32_t labelId, std::unique_ptr* block_ptr); + void AddBranch(uint32_t labelId, std::unique_ptr* block_ptr); // Add conditional branch to end of block |block_ptr|. void AddBranchCond(uint32_t cond_id, uint32_t true_id, uint32_t false_id, - std::unique_ptr* block_ptr); + std::unique_ptr* block_ptr); // Add unconditional branch to labelId to end of block block_ptr. void AddLoopMerge(uint32_t merge_id, uint32_t continue_id, - std::unique_ptr* block_ptr); + std::unique_ptr* block_ptr); // Add store of valId to ptrId to end of block block_ptr. void AddStore(uint32_t ptrId, uint32_t valId, - std::unique_ptr* block_ptr); + std::unique_ptr* block_ptr); // Add load of ptrId into resultId to end of block block_ptr. void AddLoad(uint32_t typeId, uint32_t resultId, uint32_t ptrId, - std::unique_ptr* block_ptr); + std::unique_ptr* block_ptr); // Return new label. - std::unique_ptr NewLabel(uint32_t label_id); + std::unique_ptr NewLabel(uint32_t label_id); // Returns the id for the boolean false value. Looks in the module first // and creates it if not found. Remembers it for future calls. uint32_t GetFalseId(); // Map callee params to caller args - void MapParams(ir::Function* calleeFn, ir::BasicBlock::iterator call_inst_itr, + void MapParams(Function* calleeFn, BasicBlock::iterator call_inst_itr, std::unordered_map* callee2caller); // Clone and map callee locals - void CloneAndMapLocals( - ir::Function* calleeFn, - std::vector>* new_vars, - std::unordered_map* callee2caller); + void CloneAndMapLocals(Function* calleeFn, + std::vector>* new_vars, + std::unordered_map* callee2caller); // Create return variable for callee clone code if needed. Return id // if created, otherwise 0. - uint32_t CreateReturnVar( - ir::Function* calleeFn, - std::vector>* new_vars); + uint32_t CreateReturnVar(Function* calleeFn, + std::vector>* new_vars); // Return true if instruction must be in the same block that its result // is used. - bool IsSameBlockOp(const ir::Instruction* inst) const; + bool IsSameBlockOp(const Instruction* inst) const; // Clone operands which must be in same block as consumer instructions. // Look in preCallSB for instructions that need cloning. Look in // postCallSB for instructions already cloned. Add cloned instruction // to postCallSB. - void CloneSameBlockOps( - std::unique_ptr* inst, - std::unordered_map* postCallSB, - std::unordered_map* preCallSB, - std::unique_ptr* block_ptr); + void CloneSameBlockOps(std::unique_ptr* inst, + std::unordered_map* postCallSB, + std::unordered_map* preCallSB, + std::unique_ptr* block_ptr); // Return in new_blocks the result of inlining the call at call_inst_itr // within its block at call_block_itr. The block at call_block_itr can @@ -116,13 +114,13 @@ class InlinePass : public Pass { // Also return in new_vars additional OpVariable instructions required by // and to be inserted into the caller function after the block at // call_block_itr is replaced with new_blocks. - void GenInlineCode(std::vector>* new_blocks, - std::vector>* new_vars, - ir::BasicBlock::iterator call_inst_itr, - ir::UptrVectorIterator call_block_itr); + void GenInlineCode(std::vector>* new_blocks, + std::vector>* new_vars, + BasicBlock::iterator call_inst_itr, + UptrVectorIterator call_block_itr); // Return true if |inst| is a function call that can be inlined. - bool IsInlinableFunctionCall(const ir::Instruction* inst); + bool IsInlinableFunctionCall(const Instruction* inst); // Compute structured successors for function |func|. // A block's structured successors are the blocks it branches to @@ -131,39 +129,39 @@ class InlinePass : public Pass { // This assures correct depth first search in the presence of early // returns and kills. If the successor vector contain duplicates // if the merge block, they are safely ignored by DFS. - void ComputeStructuredSuccessors(ir::Function* func); + void ComputeStructuredSuccessors(Function* func); // Return function to return ordered structure successors for a given block // Assumes ComputeStructuredSuccessors() has been called. GetBlocksFunction StructuredSuccessorsFunction(); // Return true if |func| has multiple returns - bool HasMultipleReturns(ir::Function* func); + bool HasMultipleReturns(Function* func); // Return true if |func| has no return in a loop. The current analysis // requires structured control flow, so return false if control flow not // structured ie. module is not a shader. - bool HasNoReturnInLoop(ir::Function* func); + bool HasNoReturnInLoop(Function* func); // Find all functions with multiple returns and no returns in loops - void AnalyzeReturns(ir::Function* func); + void AnalyzeReturns(Function* func); // Return true if |func| is a function that can be inlined. - bool IsInlinableFunction(ir::Function* func); + bool IsInlinableFunction(Function* func); // Update phis in succeeding blocks to point to new last block void UpdateSucceedingPhis( - std::vector>& new_blocks); + std::vector>& new_blocks); // Initialize state for optimization of |module| - void InitializeInline(ir::IRContext* c); + void InitializeInline(); // Map from function's result id to function. - std::unordered_map id2function_; + std::unordered_map id2function_; // Map from block's label id to block. TODO(dnovillo): This is superfluous wrt - // opt::CFG. It has functionality not present in opt::CFG. Consolidate. - std::unordered_map id2block_; + // CFG. It has functionality not present in CFG. Consolidate. + std::unordered_map id2block_; // Set of ids of functions with multiple returns. std::set multi_return_funcs_; @@ -179,13 +177,13 @@ class InlinePass : public Pass { // Map from block to its structured successor blocks. See // ComputeStructuredSuccessors() for definition. TODO(dnovillo): This is - // superfluous wrt opt::CFG, but it seems to be computed in a slightly + // superfluous wrt CFG, but it seems to be computed in a slightly // different way in the inliner. Can these be consolidated? - std::unordered_map> + std::unordered_map> block2structured_succs_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_INLINE_PASS_H_ +#endif // SOURCE_OPT_INLINE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/insert_extract_elim.cpp b/3rdparty/spirv-tools/source/opt/insert_extract_elim.cpp deleted file mode 100644 index 2d0d8a24b..000000000 --- a/3rdparty/spirv-tools/source/opt/insert_extract_elim.cpp +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) 2017 The Khronos Group Inc. -// Copyright (c) 2017 Valve Corporation -// Copyright (c) 2017 LunarG Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "insert_extract_elim.h" - -#include "composite.h" -#include "ir_context.h" -#include "iterator.h" -#include "latest_version_glsl_std_450_header.h" - -#include - -namespace spvtools { -namespace opt { - -namespace { - -const uint32_t kConstantValueInIdx = 0; -const uint32_t kExtractCompositeIdInIdx = 0; -const uint32_t kInsertObjectIdInIdx = 0; -const uint32_t kInsertCompositeIdInIdx = 1; -const uint32_t kVectorShuffleVec1IdInIdx = 0; -const uint32_t kVectorShuffleVec2IdInIdx = 1; -const uint32_t kVectorShuffleCompsInIdx = 2; -const uint32_t kTypeVectorCompTypeIdInIdx = 0; -const uint32_t kTypeVectorLengthInIdx = 1; -const uint32_t kTypeFloatWidthInIdx = 0; -const uint32_t kExtInstSetIdInIdx = 0; -const uint32_t kExtInstInstructionInIdx = 1; -const uint32_t kFMixXIdInIdx = 2; -const uint32_t kFMixYIdInIdx = 3; -const uint32_t kFMixAIdInIdx = 4; - -} // anonymous namespace - -uint32_t InsertExtractElimPass::DoExtract(ir::Instruction* compInst, - std::vector* pExtIndices, - uint32_t extOffset) { - ir::Instruction* cinst = compInst; - uint32_t cid = 0; - uint32_t replId = 0; - while (true) { - if (cinst->opcode() == SpvOpCompositeInsert) { - if (ExtInsMatch(*pExtIndices, cinst, extOffset)) { - // Match! Use inserted value as replacement - replId = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx); - break; - } else if (ExtInsConflict(*pExtIndices, cinst, extOffset)) { - // If extract has fewer indices than the insert, stop searching. - // Otherwise increment offset of extract indices considered and - // continue searching through the inserted value - if (pExtIndices->size() - extOffset < cinst->NumInOperands() - 2) { - break; - } else { - extOffset += cinst->NumInOperands() - 2; - cid = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx); - } - } else { - // Consider next composite in insert chain - cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx); - } - } else if (cinst->opcode() == SpvOpVectorShuffle) { - // Get length of vector1 - uint32_t v1_id = cinst->GetSingleWordInOperand(kVectorShuffleVec1IdInIdx); - ir::Instruction* v1_inst = get_def_use_mgr()->GetDef(v1_id); - uint32_t v1_type_id = v1_inst->type_id(); - ir::Instruction* v1_type_inst = get_def_use_mgr()->GetDef(v1_type_id); - uint32_t v1_len = - v1_type_inst->GetSingleWordInOperand(kTypeVectorLengthInIdx); - // Get shuffle idx - uint32_t comp_idx = (*pExtIndices)[extOffset]; - uint32_t shuffle_idx = - cinst->GetSingleWordInOperand(kVectorShuffleCompsInIdx + comp_idx); - // If undefined, give up - // TODO(greg-lunarg): Return OpUndef - if (shuffle_idx == 0xFFFFFFFF) break; - if (shuffle_idx < v1_len) { - cid = v1_id; - (*pExtIndices)[extOffset] = shuffle_idx; - } else { - cid = cinst->GetSingleWordInOperand(kVectorShuffleVec2IdInIdx); - (*pExtIndices)[extOffset] = shuffle_idx - v1_len; - } - } else if (cinst->opcode() == SpvOpExtInst && - cinst->GetSingleWordInOperand(kExtInstSetIdInIdx) == - get_feature_mgr()->GetExtInstImportId_GLSLstd450() && - cinst->GetSingleWordInOperand(kExtInstInstructionInIdx) == - GLSLstd450FMix) { - // If mixing value component is 0 or 1 we just match with x or y. - // Otherwise give up. - uint32_t comp_idx = (*pExtIndices)[extOffset]; - std::vector aIndices = {comp_idx}; - uint32_t a_id = cinst->GetSingleWordInOperand(kFMixAIdInIdx); - ir::Instruction* a_inst = get_def_use_mgr()->GetDef(a_id); - uint32_t a_comp_id = DoExtract(a_inst, &aIndices, 0); - if (a_comp_id == 0) break; - ir::Instruction* a_comp_inst = get_def_use_mgr()->GetDef(a_comp_id); - if (a_comp_inst->opcode() != SpvOpConstant) break; - // If a value is not 32-bit, give up - uint32_t a_comp_type_id = a_comp_inst->type_id(); - ir::Instruction* a_comp_type = get_def_use_mgr()->GetDef(a_comp_type_id); - if (a_comp_type->GetSingleWordInOperand(kTypeFloatWidthInIdx) != 32) - break; - uint32_t u = a_comp_inst->GetSingleWordInOperand(kConstantValueInIdx); - float* fp = reinterpret_cast(&u); - if (*fp == 0.0) - cid = cinst->GetSingleWordInOperand(kFMixXIdInIdx); - else if (*fp == 1.0) - cid = cinst->GetSingleWordInOperand(kFMixYIdInIdx); - else - break; - } else { - break; - } - cinst = get_def_use_mgr()->GetDef(cid); - } - // If search ended with CompositeConstruct or ConstantComposite - // and the extract has one index, return the appropriate component. - // TODO(greg-lunarg): Handle multiple-indices, ConstantNull, special - // vector composition, and additional CompositeInsert. - if (replId == 0 && - (cinst->opcode() == SpvOpCompositeConstruct || - cinst->opcode() == SpvOpConstantComposite) && - (*pExtIndices).size() - extOffset == 1) { - uint32_t compIdx = (*pExtIndices)[extOffset]; - // If a vector CompositeConstruct we make sure all preceding - // components are of component type (not vector composition). - uint32_t ctype_id = cinst->type_id(); - ir::Instruction* ctype_inst = get_def_use_mgr()->GetDef(ctype_id); - if (ctype_inst->opcode() == SpvOpTypeVector && - cinst->opcode() == SpvOpConstantComposite) { - uint32_t vec_comp_type_id = - ctype_inst->GetSingleWordInOperand(kTypeVectorCompTypeIdInIdx); - if (compIdx < cinst->NumInOperands()) { - uint32_t i = 0; - for (; i <= compIdx; i++) { - uint32_t compId = cinst->GetSingleWordInOperand(i); - ir::Instruction* componentInst = get_def_use_mgr()->GetDef(compId); - if (componentInst->type_id() != vec_comp_type_id) break; - } - if (i > compIdx) replId = cinst->GetSingleWordInOperand(compIdx); - } - } else { - replId = cinst->GetSingleWordInOperand(compIdx); - } - } - return replId; -} - -bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) { - bool modified = false; - for (auto bi = func->begin(); bi != func->end(); ++bi) { - ir::Instruction* inst = &*bi->begin(); - while (inst) { - switch (inst->opcode()) { - case SpvOpCompositeExtract: { - uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); - ir::Instruction* cinst = get_def_use_mgr()->GetDef(cid); - // Capture extract indices - std::vector extIndices; - uint32_t icnt = 0; - inst->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) { - if (icnt > 0) extIndices.push_back(*idp); - ++icnt; - }); - // Offset of extract indices being compared to insert indices. - // Offset increases as indices are matched. - uint32_t replId = DoExtract(cinst, &extIndices, 0); - if (replId != 0) { - const uint32_t extId = inst->result_id(); - (void)context()->ReplaceAllUsesWith(extId, replId); - inst = context()->KillInst(inst); - modified = true; - } else { - inst = inst->NextNode(); - } - } break; - default: - inst = inst->NextNode(); - break; - } - } - } - return modified; -} - -void InsertExtractElimPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); -} - -Pass::Status InsertExtractElimPass::ProcessImpl() { - // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { - return EliminateInsertExtract(fp); - }; - bool modified = ProcessEntryPointCallTree(pfn, get_module()); - return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; -} - -InsertExtractElimPass::InsertExtractElimPass() {} - -Pass::Status InsertExtractElimPass::Process(ir::IRContext* c) { - Initialize(c); - return ProcessImpl(); -} - -} // namespace opt -} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/insert_extract_elim.h b/3rdparty/spirv-tools/source/opt/insert_extract_elim.h deleted file mode 100644 index 3f1ba00bb..000000000 --- a/3rdparty/spirv-tools/source/opt/insert_extract_elim.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2017 The Khronos Group Inc. -// Copyright (c) 2017 Valve Corporation -// Copyright (c) 2017 LunarG Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef LIBSPIRV_OPT_INSERT_EXTRACT_ELIM_PASS_H_ -#define LIBSPIRV_OPT_INSERT_EXTRACT_ELIM_PASS_H_ - -#include -#include -#include -#include -#include - -#include "basic_block.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "mem_pass.h" -#include "module.h" - -namespace spvtools { -namespace opt { - -// See optimizer.hpp for documentation. -class InsertExtractElimPass : public MemPass { - public: - InsertExtractElimPass(); - const char* name() const override { return "eliminate-insert-extract"; } - Status Process(ir::IRContext*) override; - - private: - // Return id of component of |cinst| specified by |extIndices| starting with - // index at |extOffset|. Return 0 if indices cannot be matched exactly. - uint32_t DoExtract(ir::Instruction* cinst, std::vector* extIndices, - uint32_t extOffset); - - // Look for OpExtract on sequence of OpInserts in |func|. If there is a - // reaching insert which corresponds to the indices of the extract, replace - // the extract with the value that is inserted. Also resolve extracts from - // CompositeConstruct or ConstantComposite. - bool EliminateInsertExtract(ir::Function* func); - - void Initialize(ir::IRContext* c); - Pass::Status ProcessImpl(); -}; - -} // namespace opt -} // namespace spvtools - -#endif // LIBSPIRV_OPT_INSERT_EXTRACT_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/instruction.cpp b/3rdparty/spirv-tools/source/opt/instruction.cpp index ba8413e35..4cfa41d8e 100644 --- a/3rdparty/spirv-tools/source/opt/instruction.cpp +++ b/3rdparty/spirv-tools/source/opt/instruction.cpp @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "instruction.h" +#include "source/opt/instruction.h" #include -#include "disassemble.h" -#include "fold.h" -#include "ir_context.h" -#include "reflect.h" +#include "source/disassemble.h" +#include "source/opt/fold.h" +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" namespace spvtools { -namespace ir { +namespace opt { namespace { // Indices used to get particular operands out of instructions using InOperand. @@ -36,24 +36,24 @@ Instruction::Instruction(IRContext* c) : utils::IntrusiveNodeBase(), context_(c), opcode_(SpvOpNop), - type_id_(0), - result_id_(0), + has_type_id_(false), + has_result_id_(false), unique_id_(c->TakeNextUniqueId()) {} Instruction::Instruction(IRContext* c, SpvOp op) : utils::IntrusiveNodeBase(), context_(c), opcode_(op), - type_id_(0), - result_id_(0), + has_type_id_(false), + has_result_id_(false), unique_id_(c->TakeNextUniqueId()) {} Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst, std::vector&& dbg_line) : context_(c), opcode_(static_cast(inst.opcode)), - type_id_(inst.type_id), - result_id_(inst.result_id), + has_type_id_(inst.type_id != 0), + has_result_id_(inst.result_id != 0), unique_id_(c->TakeNextUniqueId()), dbg_line_insts_(std::move(dbg_line)) { assert((!IsDebugLineInst(opcode_) || dbg_line.empty()) && @@ -68,22 +68,21 @@ Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst, } Instruction::Instruction(IRContext* c, SpvOp op, uint32_t ty_id, - uint32_t res_id, - const std::vector& in_operands) + uint32_t res_id, const OperandList& in_operands) : utils::IntrusiveNodeBase(), context_(c), opcode_(op), - type_id_(ty_id), - result_id_(res_id), + has_type_id_(ty_id != 0), + has_result_id_(res_id != 0), unique_id_(c->TakeNextUniqueId()), operands_() { - if (type_id_ != 0) { + if (has_type_id_) { operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_TYPE_ID, - std::initializer_list{type_id_}); + std::initializer_list{ty_id}); } - if (result_id_ != 0) { + if (has_result_id_) { operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_RESULT_ID, - std::initializer_list{result_id_}); + std::initializer_list{res_id}); } operands_.insert(operands_.end(), in_operands.begin(), in_operands.end()); } @@ -91,16 +90,16 @@ Instruction::Instruction(IRContext* c, SpvOp op, uint32_t ty_id, Instruction::Instruction(Instruction&& that) : utils::IntrusiveNodeBase(), opcode_(that.opcode_), - type_id_(that.type_id_), - result_id_(that.result_id_), + has_type_id_(that.has_type_id_), + has_result_id_(that.has_result_id_), unique_id_(that.unique_id_), operands_(std::move(that.operands_)), dbg_line_insts_(std::move(that.dbg_line_insts_)) {} Instruction& Instruction::operator=(Instruction&& that) { opcode_ = that.opcode_; - type_id_ = that.type_id_; - result_id_ = that.result_id_; + has_type_id_ = that.has_type_id_; + has_result_id_ = that.has_result_id_; unique_id_ = that.unique_id_; operands_ = std::move(that.operands_); dbg_line_insts_ = std::move(that.dbg_line_insts_); @@ -110,8 +109,8 @@ Instruction& Instruction::operator=(Instruction&& that) { Instruction* Instruction::Clone(IRContext* c) const { Instruction* clone = new Instruction(c); clone->opcode_ = opcode_; - clone->type_id_ = type_id_; - clone->result_id_ = result_id_; + clone->has_type_id_ = has_type_id_; + clone->has_result_id_ = has_result_id_; clone->unique_id_ = c->TakeNextUniqueId(); clone->operands_ = operands_; clone->dbg_line_insts_ = dbg_line_insts_; @@ -139,15 +138,14 @@ void Instruction::ToBinaryWithoutAttachedDebugInsts( binary->insert(binary->end(), operand.words.begin(), operand.words.end()); } -void Instruction::ReplaceOperands(const std::vector& new_operands) { +void Instruction::ReplaceOperands(const OperandList& new_operands) { operands_.clear(); operands_.insert(operands_.begin(), new_operands.begin(), new_operands.end()); - operands_.shrink_to_fit(); } bool Instruction::IsReadOnlyLoad() const { if (IsLoad()) { - ir::Instruction* address_def = GetBaseAddress(); + Instruction* address_def = GetBaseAddress(); if (!address_def || address_def->opcode() != SpvOpVariable) { return false; } @@ -163,7 +161,7 @@ Instruction* Instruction::GetBaseAddress() const { "GetBaseAddress should only be called on instructions that take a " "pointer or image."); uint32_t base = GetSingleWordInOperand(kLoadBaseIndex); - ir::Instruction* base_inst = context()->get_def_use_mgr()->GetDef(base); + Instruction* base_inst = context()->get_def_use_mgr()->GetDef(base); bool done = false; while (!done) { switch (base_inst->opcode()) { @@ -219,7 +217,7 @@ bool Instruction::IsVulkanStorageImage() const { return false; } - ir::Instruction* base_type = + Instruction* base_type = context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1)); if (base_type->opcode() != SpvOpTypeImage) { return false; @@ -245,7 +243,7 @@ bool Instruction::IsVulkanSampledImage() const { return false; } - ir::Instruction* base_type = + Instruction* base_type = context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1)); if (base_type->opcode() != SpvOpTypeImage) { return false; @@ -271,7 +269,7 @@ bool Instruction::IsVulkanStorageTexelBuffer() const { return false; } - ir::Instruction* base_type = + Instruction* base_type = context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1)); if (base_type->opcode() != SpvOpTypeImage) { return false; @@ -293,7 +291,7 @@ bool Instruction::IsVulkanStorageBuffer() const { return false; } - ir::Instruction* base_type = + Instruction* base_type = context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1)); if (base_type->opcode() != SpvOpTypeStruct) { @@ -305,13 +303,13 @@ bool Instruction::IsVulkanStorageBuffer() const { bool is_buffer_block = false; context()->get_decoration_mgr()->ForEachDecoration( base_type->result_id(), SpvDecorationBufferBlock, - [&is_buffer_block](const ir::Instruction&) { is_buffer_block = true; }); + [&is_buffer_block](const Instruction&) { is_buffer_block = true; }); return is_buffer_block; } else if (storage_class == SpvStorageClassStorageBuffer) { bool is_block = false; context()->get_decoration_mgr()->ForEachDecoration( base_type->result_id(), SpvDecorationBlock, - [&is_block](const ir::Instruction&) { is_block = true; }); + [&is_block](const Instruction&) { is_block = true; }); return is_block; } return false; @@ -327,7 +325,7 @@ bool Instruction::IsVulkanUniformBuffer() const { return false; } - ir::Instruction* base_type = + Instruction* base_type = context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1)); if (base_type->opcode() != SpvOpTypeStruct) { return false; @@ -336,7 +334,7 @@ bool Instruction::IsVulkanUniformBuffer() const { bool is_block = false; context()->get_decoration_mgr()->ForEachDecoration( base_type->result_id(), SpvDecorationBlock, - [&is_block](const ir::Instruction&) { is_block = true; }); + [&is_block](const Instruction&) { is_block = true; }); return is_block; } @@ -416,12 +414,13 @@ bool Instruction::IsValidBasePointer() const { return false; } - ir::Instruction* type = context()->get_def_use_mgr()->GetDef(tid); + Instruction* type = context()->get_def_use_mgr()->GetDef(tid); if (type->opcode() != SpvOpTypePointer) { return false; } - if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) { + auto feature_mgr = context()->get_feature_mgr(); + if (feature_mgr->HasCapability(SpvCapabilityAddresses)) { // TODO: The rules here could be more restrictive. return true; } @@ -430,8 +429,27 @@ bool Instruction::IsValidBasePointer() const { return true; } + // With variable pointers, there are more valid base pointer objects. + // Variable pointers implicitly declares Variable pointers storage buffer. + SpvStorageClass storage_class = + static_cast(type->GetSingleWordInOperand(0)); + if ((feature_mgr->HasCapability(SpvCapabilityVariablePointersStorageBuffer) && + storage_class == SpvStorageClassStorageBuffer) || + (feature_mgr->HasCapability(SpvCapabilityVariablePointers) && + storage_class == SpvStorageClassWorkgroup)) { + switch (opcode()) { + case SpvOpPhi: + case SpvOpSelect: + case SpvOpFunctionCall: + case SpvOpConstantNull: + return true; + default: + break; + } + } + uint32_t pointee_type_id = type->GetSingleWordInOperand(1); - ir::Instruction* pointee_type_inst = + Instruction* pointee_type_inst = context()->get_def_use_mgr()->GetDef(pointee_type_id); if (pointee_type_inst->IsOpaqueType()) { @@ -446,7 +464,7 @@ bool Instruction::IsValidBaseImage() const { return false; } - ir::Instruction* type = context()->get_def_use_mgr()->GetDef(tid); + Instruction* type = context()->get_def_use_mgr()->GetDef(tid); return (type->opcode() == SpvOpTypeImage || type->opcode() == SpvOpTypeSampledImage); } @@ -455,13 +473,13 @@ bool Instruction::IsOpaqueType() const { if (opcode() == SpvOpTypeStruct) { bool is_opaque = false; ForEachInOperand([&is_opaque, this](const uint32_t* op_id) { - ir::Instruction* type_inst = context()->get_def_use_mgr()->GetDef(*op_id); + Instruction* type_inst = context()->get_def_use_mgr()->GetDef(*op_id); is_opaque |= type_inst->IsOpaqueType(); }); return is_opaque; } else if (opcode() == SpvOpTypeArray) { uint32_t sub_type_id = GetSingleWordInOperand(0); - ir::Instruction* sub_type_inst = + Instruction* sub_type_inst = context()->get_def_use_mgr()->GetDef(sub_type_id); return sub_type_inst->IsOpaqueType(); } else { @@ -472,15 +490,16 @@ bool Instruction::IsOpaqueType() const { bool Instruction::IsFoldable() const { return IsFoldableByFoldScalar() || - opt::GetConstantFoldingRules().HasFoldingRule(opcode()); + context()->get_instruction_folder().HasConstFoldingRule(opcode()); } bool Instruction::IsFoldableByFoldScalar() const { - if (!opt::IsFoldableOpcode(opcode())) { + const InstructionFolder& folder = context()->get_instruction_folder(); + if (!folder.IsFoldableOpcode(opcode())) { return false; } Instruction* type = context()->get_def_use_mgr()->GetDef(type_id()); - return opt::IsFoldableType(type); + return folder.IsFoldableType(type); } bool Instruction::IsFloatingPointFoldingAllowed() const { @@ -492,7 +511,7 @@ bool Instruction::IsFloatingPointFoldingAllowed() const { bool is_nocontract = false; context_->get_decoration_mgr()->WhileEachDecoration( opcode_, SpvDecorationNoContraction, - [&is_nocontract](const ir::Instruction&) { + [&is_nocontract](const Instruction&) { is_nocontract = true; return false; }); @@ -516,16 +535,27 @@ std::string Instruction::PrettyPrint(uint32_t options) const { options | SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); } -std::ostream& operator<<(std::ostream& str, const ir::Instruction& inst) { +std::ostream& operator<<(std::ostream& str, const Instruction& inst) { str << inst.PrettyPrint(); return str; } bool Instruction::IsOpcodeCodeMotionSafe() const { switch (opcode_) { + case SpvOpNop: + case SpvOpUndef: + case SpvOpLoad: + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpArrayLength: case SpvOpVectorExtractDynamic: case SpvOpVectorInsertDynamic: case SpvOpVectorShuffle: + case SpvOpCompositeConstruct: + case SpvOpCompositeExtract: + case SpvOpCompositeInsert: + case SpvOpCopyObject: + case SpvOpTranspose: case SpvOpConvertFToU: case SpvOpConvertFToS: case SpvOpConvertSToF: @@ -556,11 +586,22 @@ bool Instruction::IsOpcodeCodeMotionSafe() const { case SpvOpVectorTimesMatrix: case SpvOpMatrixTimesVector: case SpvOpMatrixTimesMatrix: + case SpvOpOuterProduct: + case SpvOpDot: + case SpvOpIAddCarry: + case SpvOpISubBorrow: + case SpvOpUMulExtended: + case SpvOpSMulExtended: + case SpvOpAny: + case SpvOpAll: + case SpvOpIsNan: + case SpvOpIsInf: case SpvOpLogicalEqual: case SpvOpLogicalNotEqual: case SpvOpLogicalOr: case SpvOpLogicalAnd: case SpvOpLogicalNot: + case SpvOpSelect: case SpvOpIEqual: case SpvOpINotEqual: case SpvOpUGreaterThan: @@ -590,11 +631,116 @@ bool Instruction::IsOpcodeCodeMotionSafe() const { case SpvOpBitwiseXor: case SpvOpBitwiseAnd: case SpvOpNot: + case SpvOpBitFieldInsert: + case SpvOpBitFieldSExtract: + case SpvOpBitFieldUExtract: + case SpvOpBitReverse: + case SpvOpBitCount: + case SpvOpSizeOf: return true; default: return false; } } -} // namespace ir +bool Instruction::IsScalarizable() const { + if (spvOpcodeIsScalarizable(opcode())) { + return true; + } + + const uint32_t kExtInstSetIdInIdx = 0; + const uint32_t kExtInstInstructionInIdx = 1; + + if (opcode() == SpvOpExtInst) { + uint32_t instSetId = + context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + + if (GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId) { + switch (GetSingleWordInOperand(kExtInstInstructionInIdx)) { + case GLSLstd450Round: + case GLSLstd450RoundEven: + case GLSLstd450Trunc: + case GLSLstd450FAbs: + case GLSLstd450SAbs: + case GLSLstd450FSign: + case GLSLstd450SSign: + case GLSLstd450Floor: + case GLSLstd450Ceil: + case GLSLstd450Fract: + case GLSLstd450Radians: + case GLSLstd450Degrees: + case GLSLstd450Sin: + case GLSLstd450Cos: + case GLSLstd450Tan: + case GLSLstd450Asin: + case GLSLstd450Acos: + case GLSLstd450Atan: + case GLSLstd450Sinh: + case GLSLstd450Cosh: + case GLSLstd450Tanh: + case GLSLstd450Asinh: + case GLSLstd450Acosh: + case GLSLstd450Atanh: + case GLSLstd450Atan2: + case GLSLstd450Pow: + case GLSLstd450Exp: + case GLSLstd450Log: + case GLSLstd450Exp2: + case GLSLstd450Log2: + case GLSLstd450Sqrt: + case GLSLstd450InverseSqrt: + case GLSLstd450Modf: + case GLSLstd450FMin: + case GLSLstd450UMin: + case GLSLstd450SMin: + case GLSLstd450FMax: + case GLSLstd450UMax: + case GLSLstd450SMax: + case GLSLstd450FClamp: + case GLSLstd450UClamp: + case GLSLstd450SClamp: + case GLSLstd450FMix: + case GLSLstd450Step: + case GLSLstd450SmoothStep: + case GLSLstd450Fma: + case GLSLstd450Frexp: + case GLSLstd450Ldexp: + case GLSLstd450FindILsb: + case GLSLstd450FindSMsb: + case GLSLstd450FindUMsb: + case GLSLstd450NMin: + case GLSLstd450NMax: + case GLSLstd450NClamp: + return true; + default: + return false; + } + } + } + return false; +} + +bool Instruction::IsOpcodeSafeToDelete() const { + if (context()->IsCombinatorInstruction(this)) { + return true; + } + + switch (opcode()) { + case SpvOpDPdx: + case SpvOpDPdy: + case SpvOpFwidth: + case SpvOpDPdxFine: + case SpvOpDPdyFine: + case SpvOpFwidthFine: + case SpvOpDPdxCoarse: + case SpvOpDPdyCoarse: + case SpvOpFwidthCoarse: + case SpvOpImageQueryLod: + return true; + default: + return false; + } +} + +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/instruction.h b/3rdparty/spirv-tools/source/opt/instruction.h index e74631057..2533ba272 100644 --- a/3rdparty/spirv-tools/source/opt/instruction.h +++ b/3rdparty/spirv-tools/source/opt/instruction.h @@ -12,31 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_INSTRUCTION_H_ -#define LIBSPIRV_OPT_INSTRUCTION_H_ +#ifndef SOURCE_OPT_INSTRUCTION_H_ +#define SOURCE_OPT_INSTRUCTION_H_ #include #include +#include +#include #include #include -#include "opcode.h" -#include "operand.h" -#include "util/ilist_node.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/util/ilist_node.h" +#include "source/util/small_vector.h" -#include "latest_version_spirv_header.h" -#include "reflect.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_spirv_header.h" +#include "source/opt/reflect.h" #include "spirv-tools/libspirv.h" namespace spvtools { -namespace ir { +namespace opt { class Function; class IRContext; class Module; class InstructionList; -// Relaxed logcial addressing: +// Relaxed logical addressing: // // In the logical addressing model, pointers cannot be stored or loaded. This // is a useful assumption because it simplifies the aliasing significantly. @@ -67,14 +71,14 @@ class InstructionList; // A *logical* operand to a SPIR-V instruction. It can be the type id, result // id, or other additional operands carried in an instruction. struct Operand { - Operand(spv_operand_type_t t, std::vector&& w) + using OperandData = utils::SmallVector; + Operand(spv_operand_type_t t, OperandData&& w) : type(t), words(std::move(w)) {} - Operand(spv_operand_type_t t, const std::vector& w) - : type(t), words(w) {} + Operand(spv_operand_type_t t, const OperandData& w) : type(t), words(w) {} - spv_operand_type_t type; // Type of this logical operand. - std::vector words; // Binary segments of this logical operand. + spv_operand_type_t type; // Type of this logical operand. + OperandData words; // Binary segments of this logical operand. friend bool operator==(const Operand& o1, const Operand& o2) { return o1.type == o2.type && o1.words == o2.words; @@ -95,8 +99,9 @@ inline bool operator!=(const Operand& o1, const Operand& o2) { // needs to change, the user should create a new instruction instead. class Instruction : public utils::IntrusiveNodeBase { public: - using iterator = std::vector::iterator; - using const_iterator = std::vector::const_iterator; + using OperandList = std::vector; + using iterator = OperandList::iterator; + using const_iterator = OperandList::const_iterator; // Creates a default OpNop instruction. // This exists solely for containers that can't do without. Should be removed. @@ -104,8 +109,8 @@ class Instruction : public utils::IntrusiveNodeBase { : utils::IntrusiveNodeBase(), context_(nullptr), opcode_(SpvOpNop), - type_id_(0), - result_id_(0), + has_type_id_(false), + has_result_id_(false), unique_id_(0) {} // Creates a default OpNop instruction. @@ -123,7 +128,7 @@ class Instruction : public utils::IntrusiveNodeBase { // Creates an instruction with the given opcode |op|, type id: |ty_id|, // result id: |res_id| and input operands: |in_operands|. Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id, - const std::vector& in_operands); + const OperandList& in_operands); // TODO: I will want to remove these, but will first have to remove the use of // std::vector. @@ -150,8 +155,12 @@ class Instruction : public utils::IntrusiveNodeBase { // TODO(qining): Remove this function when instruction building and insertion // is well implemented. void SetOpcode(SpvOp op) { opcode_ = op; } - uint32_t type_id() const { return type_id_; } - uint32_t result_id() const { return result_id_; } + uint32_t type_id() const { + return has_type_id_ ? GetSingleWordOperand(0) : 0; + } + uint32_t result_id() const { + return has_result_id_ ? GetSingleWordOperand(has_type_id_ ? 1 : 0) : 0; + } uint32_t unique_id() const { assert(unique_id_ != 0); return unique_id_; @@ -197,18 +206,18 @@ class Instruction : public utils::IntrusiveNodeBase { // words. uint32_t GetSingleWordOperand(uint32_t index) const; // Sets the |index|-th in-operand's data to the given |data|. - inline void SetInOperand(uint32_t index, std::vector&& data); + inline void SetInOperand(uint32_t index, Operand::OperandData&& data); // Sets the |index|-th operand's data to the given |data|. // This is for in-operands modification only, but with |index| expressed in // terms of operand index rather than in-operand index. - inline void SetOperand(uint32_t index, std::vector&& data); + inline void SetOperand(uint32_t index, Operand::OperandData&& data); // Replace all of the in operands with those in |new_operands|. - inline void SetInOperands(std::vector&& new_operands); + inline void SetInOperands(OperandList&& new_operands); // Sets the result type id. inline void SetResultType(uint32_t ty_id); // Sets the result id inline void SetResultId(uint32_t res_id); - inline bool HasResultId() const { return result_id_ != 0; } + inline bool HasResultId() const { return has_result_id_; } // Remove the |index|-th operand void RemoveOperand(uint32_t index) { operands_.erase(operands_.begin() + index); @@ -291,7 +300,7 @@ class Instruction : public utils::IntrusiveNodeBase { // Replaces the operands to the instruction with |new_operands|. The caller // is responsible for building a complete and valid list of operands for // this instruction. - void ReplaceOperands(const std::vector& new_operands); + void ReplaceOperands(const OperandList& new_operands); // Returns true if the instruction annotates an id with a decoration. inline bool IsDecoration() const; @@ -405,10 +414,25 @@ class Instruction : public utils::IntrusiveNodeBase { // is always added to |options|. std::string PrettyPrint(uint32_t options = 0u) const; + // Returns true if the result can be a vector and the result of each component + // depends on the corresponding component of any vector inputs. + bool IsScalarizable() const; + + // Return true if the only effect of this instructions is the result. + bool IsOpcodeSafeToDelete() const; + + // Returns true if it is valid to use the result of |inst| as the base + // pointer for a load or store. In this case, valid is defined by the relaxed + // logical addressing rules when using logical addressing. Normal validation + // rules for physical addressing. + bool IsValidBasePointer() const; + private: // Returns the total count of result type id and result id. uint32_t TypeResultIdCount() const { - return (type_id_ != 0) + (result_id_ != 0); + if (has_type_id_ && has_result_id_) return 2; + if (has_type_id_ || has_result_id_) return 1; + return 0; } // Returns true if the instruction declares a variable that is read-only. The @@ -417,23 +441,17 @@ class Instruction : public utils::IntrusiveNodeBase { bool IsReadOnlyVariableShaders() const; bool IsReadOnlyVariableKernel() const; - // Returns true if it is valid to use the result of |inst| as the base - // pointer for a load or store. In this case, valid is defined by the relaxed - // logical addressing rules when using logical addressing. Normal validation - // rules for physical addressing. - bool IsValidBasePointer() const; - // Returns true if the result of |inst| can be used as the base image for an // instruction that samples a image, reads an image, or writes to an image. bool IsValidBaseImage() const; IRContext* context_; // IR Context SpvOp opcode_; // Opcode - uint32_t type_id_; // Result type id. A value of 0 means no result type id. - uint32_t result_id_; // Result id. A value of 0 means no result id. + bool has_type_id_; // True if the instruction has a type id + bool has_result_id_; // True if the instruction has a result id uint32_t unique_id_; // Unique instruction id // All logical operands, including result type id and result id. - std::vector operands_; + OperandList operands_; // Opline and OpNoLine instructions preceding this instruction. Note that for // Instructions representing OpLine or OpNonLine itself, this field should be // empty. @@ -448,7 +466,7 @@ class Instruction : public utils::IntrusiveNodeBase { // to provide the correct interpretation of types, constants, etc. // // Disassembly uses raw ids (not pretty printed names). -std::ostream& operator<<(std::ostream& str, const ir::Instruction& inst); +std::ostream& operator<<(std::ostream& str, const Instruction& inst); inline bool Instruction::operator==(const Instruction& other) const { return unique_id() == other.unique_id(); @@ -477,18 +495,18 @@ inline void Instruction::AddOperand(Operand&& operand) { } inline void Instruction::SetInOperand(uint32_t index, - std::vector&& data) { + Operand::OperandData&& data) { SetOperand(index + TypeResultIdCount(), std::move(data)); } inline void Instruction::SetOperand(uint32_t index, - std::vector&& data) { + Operand::OperandData&& data) { assert(index < operands_.size() && "operand index out of bound"); assert(index >= TypeResultIdCount() && "operand is not a in-operand"); operands_[index].words = std::move(data); } -inline void Instruction::SetInOperands(std::vector&& new_operands) { +inline void Instruction::SetInOperands(OperandList&& new_operands) { // Remove the old in operands. operands_.erase(operands_.begin() + TypeResultIdCount(), operands_.end()); // Add the new in operands. @@ -496,28 +514,43 @@ inline void Instruction::SetInOperands(std::vector&& new_operands) { } inline void Instruction::SetResultId(uint32_t res_id) { - result_id_ = res_id; - auto ridx = (type_id_ != 0) ? 1 : 0; - assert(operands_[ridx].type == SPV_OPERAND_TYPE_RESULT_ID); + // TODO(dsinclair): Allow setting a result id if there wasn't one + // previously. Need to make room in the operands_ array to place the result, + // and update the has_result_id_ flag. + assert(has_result_id_); + + // TODO(dsinclair): Allow removing the result id. This needs to make sure, + // if there was a result id previously to remove it from the operands_ array + // and reset the has_result_id_ flag. + assert(res_id != 0); + + auto ridx = has_type_id_ ? 1 : 0; operands_[ridx].words = {res_id}; } inline void Instruction::SetResultType(uint32_t ty_id) { - if (type_id_ != 0) { - type_id_ = ty_id; - assert(operands_.front().type == SPV_OPERAND_TYPE_TYPE_ID); - operands_.front().words = {ty_id}; - } + // TODO(dsinclair): Allow setting a type id if there wasn't one + // previously. Need to make room in the operands_ array to place the result, + // and update the has_type_id_ flag. + assert(has_type_id_); + + // TODO(dsinclair): Allow removing the type id. This needs to make sure, + // if there was a type id previously to remove it from the operands_ array + // and reset the has_type_id_ flag. + assert(ty_id != 0); + + operands_.front().words = {ty_id}; } inline bool Instruction::IsNop() const { - return opcode_ == SpvOpNop && type_id_ == 0 && result_id_ == 0 && + return opcode_ == SpvOpNop && !has_type_id_ && !has_result_id_ && operands_.empty(); } inline void Instruction::ToNop() { opcode_ = SpvOpNop; - type_id_ = result_id_ = 0; + has_type_id_ = false; + has_result_id_ = false; operands_.clear(); } @@ -566,9 +599,6 @@ inline void Instruction::ForEachInst( inline void Instruction::ForEachId(const std::function& f) { for (auto& opnd : operands_) if (spvIsIdType(opnd.type)) f(&opnd.words[0]); - if (type_id_ != 0u) type_id_ = GetSingleWordOperand(0u); - if (result_id_ != 0u) - result_id_ = GetSingleWordOperand(type_id_ == 0u ? 0u : 1u); } inline void Instruction::ForEachId( @@ -699,7 +729,7 @@ bool Instruction::IsAtomicOp() const { return spvOpcodeIsAtomicOp(opcode()); } bool Instruction::IsConstant() const { return IsCompileTimeConstantInst(opcode()); } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_INSTRUCTION_H_ +#endif // SOURCE_OPT_INSTRUCTION_H_ diff --git a/3rdparty/spirv-tools/source/opt/instruction_list.cpp b/3rdparty/spirv-tools/source/opt/instruction_list.cpp index d8ddb84e5..385a136ec 100644 --- a/3rdparty/spirv-tools/source/opt/instruction_list.cpp +++ b/3rdparty/spirv-tools/source/opt/instruction_list.cpp @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "instruction_list.h" +#include "source/opt/instruction_list.h" namespace spvtools { -namespace ir { +namespace opt { InstructionList::iterator InstructionList::iterator::InsertBefore( std::vector>&& list) { @@ -32,5 +32,5 @@ InstructionList::iterator InstructionList::iterator::InsertBefore( i.get()->InsertBefore(node_); return iterator(i.release()); } -} // namespace ir -} // namespace spvtools \ No newline at end of file +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/instruction_list.h b/3rdparty/spirv-tools/source/opt/instruction_list.h index 182317fb4..ea1cc7c46 100644 --- a/3rdparty/spirv-tools/source/opt/instruction_list.h +++ b/3rdparty/spirv-tools/source/opt/instruction_list.h @@ -13,23 +13,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_INSTRUCTION_LIST_H_ -#define LIBSPIRV_OPT_INSTRUCTION_LIST_H_ +#ifndef SOURCE_OPT_INSTRUCTION_LIST_H_ +#define SOURCE_OPT_INSTRUCTION_LIST_H_ #include #include +#include #include #include -#include "instruction.h" -#include "operand.h" -#include "util/ilist.h" - -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" +#include "source/operand.h" +#include "source/opt/instruction.h" +#include "source/util/ilist.h" #include "spirv-tools/libspirv.h" namespace spvtools { -namespace ir { +namespace opt { // This class is intended to be the container for Instructions. This container // owns the instructions that are in it. When removing an Instruction from the @@ -124,7 +124,7 @@ void InstructionList::clear() { } } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_INSTRUCTION_LIST_H_ +#endif // SOURCE_OPT_INSTRUCTION_LIST_H_ diff --git a/3rdparty/spirv-tools/source/opt/ir_builder.h b/3rdparty/spirv-tools/source/opt/ir_builder.h index aba6ef360..2dab76e52 100644 --- a/3rdparty/spirv-tools/source/opt/ir_builder.h +++ b/3rdparty/spirv-tools/source/opt/ir_builder.h @@ -12,13 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_IR_BUILDER_H_ -#define LIBSPIRV_OPT_IR_BUILDER_H_ +#ifndef SOURCE_OPT_IR_BUILDER_H_ +#define SOURCE_OPT_IR_BUILDER_H_ + +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/constants.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" -#include "opt/basic_block.h" -#include "opt/constants.h" -#include "opt/instruction.h" -#include "opt/ir_context.h" namespace spvtools { namespace opt { @@ -33,13 +39,13 @@ const uint32_t kInvalidId = std::numeric_limits::max(); // - Instruction to block analysis class InstructionBuilder { public: - using InsertionPointTy = spvtools::ir::BasicBlock::iterator; + using InsertionPointTy = BasicBlock::iterator; // Creates an InstructionBuilder, all new instructions will be inserted before // the instruction |insert_before|. InstructionBuilder( - ir::IRContext* context, ir::Instruction* insert_before, - ir::IRContext::Analysis preserved_analyses = ir::IRContext::kAnalysisNone) + IRContext* context, Instruction* insert_before, + IRContext::Analysis preserved_analyses = IRContext::kAnalysisNone) : InstructionBuilder(context, context->get_instr_block(insert_before), InsertionPointTy(insert_before), preserved_analyses) {} @@ -47,17 +53,17 @@ class InstructionBuilder { // Creates an InstructionBuilder, all new instructions will be inserted at the // end of the basic block |parent_block|. InstructionBuilder( - ir::IRContext* context, ir::BasicBlock* parent_block, - ir::IRContext::Analysis preserved_analyses = ir::IRContext::kAnalysisNone) + IRContext* context, BasicBlock* parent_block, + IRContext::Analysis preserved_analyses = IRContext::kAnalysisNone) : InstructionBuilder(context, parent_block, parent_block->end(), preserved_analyses) {} // Creates a new selection merge instruction. // The id |merge_id| is the merge basic block id. - ir::Instruction* AddSelectionMerge( + Instruction* AddSelectionMerge( uint32_t merge_id, uint32_t selection_control = SpvSelectionControlMaskNone) { - std::unique_ptr new_branch_merge(new ir::Instruction( + std::unique_ptr new_branch_merge(new Instruction( GetContext(), SpvOpSelectionMerge, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_SELECTION_CONTROL, @@ -68,8 +74,8 @@ class InstructionBuilder { // Creates a new branch instruction to |label_id|. // Note that the user must make sure the final basic block is // well formed. - ir::Instruction* AddBranch(uint32_t label_id) { - std::unique_ptr new_branch(new ir::Instruction( + Instruction* AddBranch(uint32_t label_id) { + std::unique_ptr new_branch(new Instruction( GetContext(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); return AddInstruction(std::move(new_branch)); @@ -90,14 +96,14 @@ class InstructionBuilder { // selection merge instruction. // Note that the user must make sure the final basic block is // well formed. - ir::Instruction* AddConditionalBranch( + Instruction* AddConditionalBranch( uint32_t cond_id, uint32_t true_id, uint32_t false_id, uint32_t merge_id = kInvalidId, uint32_t selection_control = SpvSelectionControlMaskNone) { if (merge_id != kInvalidId) { AddSelectionMerge(merge_id, selection_control); } - std::unique_ptr new_branch(new ir::Instruction( + std::unique_ptr new_branch(new Instruction( GetContext(), SpvOpBranchConditional, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, @@ -118,28 +124,28 @@ class InstructionBuilder { // selection merge instruction. // Note that the user must make sure the final basic block is // well formed. - ir::Instruction* AddSwitch( + Instruction* AddSwitch( uint32_t selector_id, uint32_t default_id, - const std::vector, uint32_t>>& targets, + const std::vector>& targets, uint32_t merge_id = kInvalidId, uint32_t selection_control = SpvSelectionControlMaskNone) { if (merge_id != kInvalidId) { AddSelectionMerge(merge_id, selection_control); } - std::vector operands; + std::vector operands; operands.emplace_back( - ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {selector_id}}); + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {selector_id}}); operands.emplace_back( - ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {default_id}}); + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {default_id}}); for (auto& target : targets) { operands.emplace_back( - ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, - target.first}); - operands.emplace_back(ir::Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, - {target.second}}); + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + target.first}); + operands.emplace_back( + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {target.second}}); } - std::unique_ptr new_switch( - new ir::Instruction(GetContext(), SpvOpSwitch, 0, 0, operands)); + std::unique_ptr new_switch( + new Instruction(GetContext(), SpvOpSwitch, 0, 0, operands)); return AddInstruction(std::move(new_switch)); } @@ -147,14 +153,13 @@ class InstructionBuilder { // The id |type| must be the id of the phi instruction's type. // The vector |incomings| must be a sequence of pairs of . - ir::Instruction* AddPhi(uint32_t type, - const std::vector& incomings) { + Instruction* AddPhi(uint32_t type, const std::vector& incomings) { assert(incomings.size() % 2 == 0 && "A sequence of pairs is expected"); - std::vector phi_ops; + std::vector phi_ops; for (size_t i = 0; i < incomings.size(); i++) { phi_ops.push_back({SPV_OPERAND_TYPE_ID, {incomings[i]}}); } - std::unique_ptr phi_inst(new ir::Instruction( + std::unique_ptr phi_inst(new Instruction( GetContext(), SpvOpPhi, type, GetContext()->TakeNextId(), phi_ops)); return AddInstruction(std::move(phi_inst)); } @@ -164,8 +169,8 @@ class InstructionBuilder { // |op1| and |op2| types. // The id |op1| is the left hand side of the operation. // The id |op2| is the right hand side of the operation. - ir::Instruction* AddIAdd(uint32_t type, uint32_t op1, uint32_t op2) { - std::unique_ptr inst(new ir::Instruction( + Instruction* AddIAdd(uint32_t type, uint32_t op1, uint32_t op2) { + std::unique_ptr inst(new Instruction( GetContext(), SpvOpIAdd, type, GetContext()->TakeNextId(), {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}})); return AddInstruction(std::move(inst)); @@ -175,10 +180,10 @@ class InstructionBuilder { // The id |op1| is the left hand side of the operation. // The id |op2| is the right hand side of the operation. // It is assumed that |op1| and |op2| have the same underlying type. - ir::Instruction* AddULessThan(uint32_t op1, uint32_t op2) { + Instruction* AddULessThan(uint32_t op1, uint32_t op2) { analysis::Bool bool_type; uint32_t type = GetContext()->get_type_mgr()->GetId(&bool_type); - std::unique_ptr inst(new ir::Instruction( + std::unique_ptr inst(new Instruction( GetContext(), SpvOpULessThan, type, GetContext()->TakeNextId(), {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}})); return AddInstruction(std::move(inst)); @@ -188,10 +193,10 @@ class InstructionBuilder { // The id |op1| is the left hand side of the operation. // The id |op2| is the right hand side of the operation. // It is assumed that |op1| and |op2| have the same underlying type. - ir::Instruction* AddSLessThan(uint32_t op1, uint32_t op2) { + Instruction* AddSLessThan(uint32_t op1, uint32_t op2) { analysis::Bool bool_type; uint32_t type = GetContext()->get_type_mgr()->GetId(&bool_type); - std::unique_ptr inst(new ir::Instruction( + std::unique_ptr inst(new Instruction( GetContext(), SpvOpSLessThan, type, GetContext()->TakeNextId(), {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}})); return AddInstruction(std::move(inst)); @@ -201,8 +206,8 @@ class InstructionBuilder { // |op1|. The id |op1| is the left hand side of the operation. The id |op2| is // the right hand side of the operation. It is assumed that |op1| and |op2| // have the same underlying type. - ir::Instruction* AddLessThan(uint32_t op1, uint32_t op2) { - ir::Instruction* op1_insn = context_->get_def_use_mgr()->GetDef(op1); + Instruction* AddLessThan(uint32_t op1, uint32_t op2) { + Instruction* op1_insn = context_->get_def_use_mgr()->GetDef(op1); analysis::Type* type = GetContext()->get_type_mgr()->GetType(op1_insn->type_id()); analysis::Integer* int_type = type->AsInteger(); @@ -218,41 +223,40 @@ class InstructionBuilder { // |type| must match the types of |true_value| and |false_value|. It is up to // the caller to ensure that |cond| is a correct type (bool or vector of // bool) for |type|. - ir::Instruction* AddSelect(uint32_t type, uint32_t cond, uint32_t true_value, - uint32_t false_value) { - std::unique_ptr select(new ir::Instruction( + Instruction* AddSelect(uint32_t type, uint32_t cond, uint32_t true_value, + uint32_t false_value) { + std::unique_ptr select(new Instruction( GetContext(), SpvOpSelect, type, GetContext()->TakeNextId(), - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {cond}}, - {SPV_OPERAND_TYPE_ID, {true_value}}, - {SPV_OPERAND_TYPE_ID, {false_value}}})); + std::initializer_list{{SPV_OPERAND_TYPE_ID, {cond}}, + {SPV_OPERAND_TYPE_ID, {true_value}}, + {SPV_OPERAND_TYPE_ID, {false_value}}})); return AddInstruction(std::move(select)); } // Adds a signed int32 constant to the binary. // The |value| parameter is the constant value to be added. - ir::Instruction* Add32BitSignedIntegerConstant(int32_t value) { + Instruction* Add32BitSignedIntegerConstant(int32_t value) { return Add32BitConstantInteger(value, true); } // Create a composite construct. // |type| should be a composite type and the number of elements it has should // match the size od |ids|. - ir::Instruction* AddCompositeConstruct(uint32_t type, - const std::vector& ids) { - std::vector ops; + Instruction* AddCompositeConstruct(uint32_t type, + const std::vector& ids) { + std::vector ops; for (auto id : ids) { ops.emplace_back(SPV_OPERAND_TYPE_ID, std::initializer_list{id}); } - std::unique_ptr construct( - new ir::Instruction(GetContext(), SpvOpCompositeConstruct, type, - GetContext()->TakeNextId(), ops)); + std::unique_ptr construct( + new Instruction(GetContext(), SpvOpCompositeConstruct, type, + GetContext()->TakeNextId(), ops)); return AddInstruction(std::move(construct)); } // Adds an unsigned int32 constant to the binary. // The |value| parameter is the constant value to be added. - ir::Instruction* Add32BitUnsignedIntegerConstant(uint32_t value) { + Instruction* Add32BitUnsignedIntegerConstant(uint32_t value) { return Add32BitConstantInteger(value, false); } @@ -261,7 +265,7 @@ class InstructionBuilder { // signed constant otherwise as an unsigned constant. If |sign| is false the // value must not be a negative number. template - ir::Instruction* Add32BitConstantInteger(T value, bool sign) { + Instruction* Add32BitConstantInteger(T value, bool sign) { // Assert that we are not trying to store a negative number in an unsigned // type. if (!sign) @@ -285,55 +289,64 @@ class InstructionBuilder { uint32_t word = value; // Create the constant value. - const opt::analysis::Constant* constant = + const analysis::Constant* constant = GetContext()->get_constant_mgr()->GetConstant(rebuilt_type, {word}); // Create the OpConstant instruction using the type and the value. return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant); } - ir::Instruction* AddCompositeExtract( - uint32_t type, uint32_t id_of_composite, - const std::vector& index_list) { - std::vector operands; + Instruction* AddCompositeExtract(uint32_t type, uint32_t id_of_composite, + const std::vector& index_list) { + std::vector operands; operands.push_back({SPV_OPERAND_TYPE_ID, {id_of_composite}}); for (uint32_t index : index_list) { operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}); } - std::unique_ptr new_inst( - new ir::Instruction(GetContext(), SpvOpCompositeExtract, type, - GetContext()->TakeNextId(), operands)); + std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpCompositeExtract, type, + GetContext()->TakeNextId(), operands)); return AddInstruction(std::move(new_inst)); } // Creates an unreachable instruction. - ir::Instruction* AddUnreachable() { - std::unique_ptr select( - new ir::Instruction(GetContext(), SpvOpUnreachable, 0, 0, - std::initializer_list{})); + Instruction* AddUnreachable() { + std::unique_ptr select( + new Instruction(GetContext(), SpvOpUnreachable, 0, 0, + std::initializer_list{})); return AddInstruction(std::move(select)); } - ir::Instruction* AddAccessChain(uint32_t type_id, uint32_t base_ptr_id, - std::vector ids) { - std::vector operands; + Instruction* AddAccessChain(uint32_t type_id, uint32_t base_ptr_id, + std::vector ids) { + std::vector operands; operands.push_back({SPV_OPERAND_TYPE_ID, {base_ptr_id}}); for (uint32_t index_id : ids) { operands.push_back({SPV_OPERAND_TYPE_ID, {index_id}}); } - std::unique_ptr new_inst( - new ir::Instruction(GetContext(), SpvOpAccessChain, type_id, - GetContext()->TakeNextId(), operands)); + std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpAccessChain, type_id, + GetContext()->TakeNextId(), operands)); + return AddInstruction(std::move(new_inst)); + } + + Instruction* AddLoad(uint32_t type_id, uint32_t base_ptr_id) { + std::vector operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {base_ptr_id}}); + + std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpLoad, type_id, + GetContext()->TakeNextId(), operands)); return AddInstruction(std::move(new_inst)); } // Inserts the new instruction before the insertion point. - ir::Instruction* AddInstruction(std::unique_ptr&& insn) { - ir::Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn)); + Instruction* AddInstruction(std::unique_ptr&& insn) { + Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn)); UpdateInstrToBlockMapping(insn_ptr); UpdateDefUseMgr(insn_ptr); return insn_ptr; @@ -344,68 +357,65 @@ class InstructionBuilder { // Change the insertion point to insert before the instruction // |insert_before|. - void SetInsertPoint(ir::Instruction* insert_before) { + void SetInsertPoint(Instruction* insert_before) { parent_ = context_->get_instr_block(insert_before); insert_before_ = InsertionPointTy(insert_before); } // Change the insertion point to insert at the end of the basic block // |parent_block|. - void SetInsertPoint(ir::BasicBlock* parent_block) { + void SetInsertPoint(BasicBlock* parent_block) { parent_ = parent_block; insert_before_ = parent_block->end(); } // Returns the context which instructions are constructed for. - ir::IRContext* GetContext() const { return context_; } + IRContext* GetContext() const { return context_; } // Returns the set of preserved analyses. - inline ir::IRContext::Analysis GetPreservedAnalysis() const { + inline IRContext::Analysis GetPreservedAnalysis() const { return preserved_analyses_; } private: - InstructionBuilder(ir::IRContext* context, ir::BasicBlock* parent, + InstructionBuilder(IRContext* context, BasicBlock* parent, InsertionPointTy insert_before, - ir::IRContext::Analysis preserved_analyses) + IRContext::Analysis preserved_analyses) : context_(context), parent_(parent), insert_before_(insert_before), preserved_analyses_(preserved_analyses) { - assert(!(preserved_analyses_ & - ~(ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping))); + assert(!(preserved_analyses_ & ~(IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping))); } // Returns true if the users requested to update |analysis|. - inline bool IsAnalysisUpdateRequested( - ir::IRContext::Analysis analysis) const { + inline bool IsAnalysisUpdateRequested(IRContext::Analysis analysis) const { return preserved_analyses_ & analysis; } // Updates the def/use manager if the user requested it. If he did not request // an update, this function does nothing. - inline void UpdateDefUseMgr(ir::Instruction* insn) { - if (IsAnalysisUpdateRequested(ir::IRContext::kAnalysisDefUse)) + inline void UpdateDefUseMgr(Instruction* insn) { + if (IsAnalysisUpdateRequested(IRContext::kAnalysisDefUse)) GetContext()->get_def_use_mgr()->AnalyzeInstDefUse(insn); } // Updates the instruction to block analysis if the user requested it. If he // did not request an update, this function does nothing. - inline void UpdateInstrToBlockMapping(ir::Instruction* insn) { - if (IsAnalysisUpdateRequested( - ir::IRContext::kAnalysisInstrToBlockMapping) && + inline void UpdateInstrToBlockMapping(Instruction* insn) { + if (IsAnalysisUpdateRequested(IRContext::kAnalysisInstrToBlockMapping) && parent_) GetContext()->set_instr_block(insn, parent_); } - ir::IRContext* context_; - ir::BasicBlock* parent_; + IRContext* context_; + BasicBlock* parent_; InsertionPointTy insert_before_; - const ir::IRContext::Analysis preserved_analyses_; + const IRContext::Analysis preserved_analyses_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_IR_BUILDER_H_ +#endif // SOURCE_OPT_IR_BUILDER_H_ diff --git a/3rdparty/spirv-tools/source/opt/ir_context.cpp b/3rdparty/spirv-tools/source/opt/ir_context.cpp index 856e40367..742ac1f62 100644 --- a/3rdparty/spirv-tools/source/opt/ir_context.cpp +++ b/3rdparty/spirv-tools/source/opt/ir_context.cpp @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ir_context.h" -#include "latest_version_glsl_std_450_header.h" -#include "log.h" -#include "mem_pass.h" -#include "reflect.h" +#include "source/opt/ir_context.h" #include +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/log.h" +#include "source/opt/mem_pass.h" +#include "source/opt/reflect.h" + namespace spvtools { -namespace ir { +namespace opt { void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) { if (set & kAnalysisDefUse) { @@ -48,6 +49,12 @@ void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) { if (set & kAnalysisScalarEvolution) { BuildScalarEvolutionAnalysis(); } + if (set & kAnalysisRegisterPressure) { + BuildRegPressureAnalysis(); + } + if (set & kAnalysisValueNumberTable) { + BuildValueNumberTable(); + } } void IRContext::InvalidateAnalysesExceptFor( @@ -79,11 +86,14 @@ void IRContext::InvalidateAnalyses(IRContext::Analysis analyses_to_invalidate) { if (analyses_to_invalidate & kAnalysisNameMap) { id_to_name_.reset(nullptr); } + if (analyses_to_invalidate & kAnalysisValueNumberTable) { + vn_table_.reset(nullptr); + } valid_analyses_ = Analysis(valid_analyses_ & ~analyses_to_invalidate); } -Instruction* IRContext::KillInst(ir::Instruction* inst) { +Instruction* IRContext::KillInst(Instruction* inst) { if (!inst) { return nullptr; } @@ -105,11 +115,11 @@ Instruction* IRContext::KillInst(ir::Instruction* inst) { } } - if (type_mgr_ && ir::IsTypeInst(inst->opcode())) { + if (type_mgr_ && IsTypeInst(inst->opcode())) { type_mgr_->RemoveId(inst->result_id()); } - if (constant_mgr_ && ir::IsConstantInst(inst->opcode())) { + if (constant_mgr_ && IsConstantInst(inst->opcode())) { constant_mgr_->RemoveId(inst->result_id()); } @@ -129,7 +139,7 @@ Instruction* IRContext::KillInst(ir::Instruction* inst) { } bool IRContext::KillDef(uint32_t id) { - ir::Instruction* def = get_def_use_mgr()->GetDef(id); + Instruction* def = get_def_use_mgr()->GetDef(id); if (def != nullptr) { KillInst(def); return true; @@ -144,15 +154,15 @@ bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) { assert(get_def_use_mgr()->GetDef(after) && "'after' is not a registered def."); - std::vector> uses_to_update; + std::vector> uses_to_update; get_def_use_mgr()->ForEachUse( - before, [&uses_to_update](ir::Instruction* user, uint32_t index) { + before, [&uses_to_update](Instruction* user, uint32_t index) { uses_to_update.emplace_back(user, index); }); - ir::Instruction* prev = nullptr; + Instruction* prev = nullptr; for (auto p : uses_to_update) { - ir::Instruction* user = p.first; + Instruction* user = p.first; uint32_t index = p.second; if (prev == nullptr || prev != user) { ForgetUses(user); @@ -182,7 +192,7 @@ bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) { user->SetInOperand(in_operand_pos, {after}); } AnalyzeUses(user); - }; + } return true; } @@ -193,7 +203,7 @@ bool IRContext::IsConsistent() { #endif if (AreAnalysesValid(kAnalysisDefUse)) { - opt::analysis::DefUseManager new_def_use(module()); + analysis::DefUseManager new_def_use(module()); if (*get_def_use_mgr() != new_def_use) { return false; } @@ -202,7 +212,7 @@ bool IRContext::IsConsistent() { if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) { for (auto& func : *module()) { for (auto& block : func) { - if (!block.WhileEachInst([this, &block](ir::Instruction* inst) { + if (!block.WhileEachInst([this, &block](Instruction* inst) { if (get_instr_block(inst) != &block) { return false; } @@ -220,7 +230,7 @@ bool IRContext::IsConsistent() { return true; } -void spvtools::ir::IRContext::ForgetUses(Instruction* inst) { +void IRContext::ForgetUses(Instruction* inst) { if (AreAnalysesValid(kAnalysisDefUse)) { get_def_use_mgr()->EraseUseRecordsOfOperandIds(inst); } @@ -248,18 +258,18 @@ void IRContext::AnalyzeUses(Instruction* inst) { } void IRContext::KillNamesAndDecorates(uint32_t id) { - std::vector decorations = + std::vector decorations = get_decoration_mgr()->GetDecorationsFor(id, true); for (Instruction* inst : decorations) { KillInst(inst); } - std::vector name_to_kill; + std::vector name_to_kill; for (auto name : GetNames(id)) { name_to_kill.push_back(name.second); } - for (ir::Instruction* name_inst : name_to_kill) { + for (Instruction* name_inst : name_to_kill) { KillInst(name_inst); } } @@ -272,165 +282,168 @@ void IRContext::KillNamesAndDecorates(Instruction* inst) { void IRContext::AddCombinatorsForCapability(uint32_t capability) { if (capability == SpvCapabilityShader) { - combinator_ops_[0].insert({ - SpvOpNop, - SpvOpUndef, - SpvOpConstant, - SpvOpConstantTrue, - SpvOpConstantFalse, - SpvOpConstantComposite, - SpvOpConstantSampler, - SpvOpConstantNull, - SpvOpTypeVoid, - SpvOpTypeBool, - SpvOpTypeInt, - SpvOpTypeFloat, - SpvOpTypeVector, - SpvOpTypeMatrix, - SpvOpTypeImage, - SpvOpTypeSampler, - SpvOpTypeSampledImage, - SpvOpTypeArray, - SpvOpTypeRuntimeArray, - SpvOpTypeStruct, - SpvOpTypeOpaque, - SpvOpTypePointer, - SpvOpTypeFunction, - SpvOpTypeEvent, - SpvOpTypeDeviceEvent, - SpvOpTypeReserveId, - SpvOpTypeQueue, - SpvOpTypePipe, - SpvOpTypeForwardPointer, - SpvOpVariable, - SpvOpImageTexelPointer, - SpvOpLoad, - SpvOpAccessChain, - SpvOpInBoundsAccessChain, - SpvOpArrayLength, - SpvOpVectorExtractDynamic, - SpvOpVectorInsertDynamic, - SpvOpVectorShuffle, - SpvOpCompositeConstruct, - SpvOpCompositeExtract, - SpvOpCompositeInsert, - SpvOpCopyObject, - SpvOpTranspose, - SpvOpSampledImage, - SpvOpImageSampleImplicitLod, - SpvOpImageSampleExplicitLod, - SpvOpImageSampleDrefImplicitLod, - SpvOpImageSampleDrefExplicitLod, - SpvOpImageSampleProjImplicitLod, - SpvOpImageSampleProjExplicitLod, - SpvOpImageSampleProjDrefImplicitLod, - SpvOpImageSampleProjDrefExplicitLod, - SpvOpImageFetch, - SpvOpImageGather, - SpvOpImageDrefGather, - SpvOpImageRead, - SpvOpImage, - SpvOpConvertFToU, - SpvOpConvertFToS, - SpvOpConvertSToF, - SpvOpConvertUToF, - SpvOpUConvert, - SpvOpSConvert, - SpvOpFConvert, - SpvOpQuantizeToF16, - SpvOpBitcast, - SpvOpSNegate, - SpvOpFNegate, - SpvOpIAdd, - SpvOpFAdd, - SpvOpISub, - SpvOpFSub, - SpvOpIMul, - SpvOpFMul, - SpvOpUDiv, - SpvOpSDiv, - SpvOpFDiv, - SpvOpUMod, - SpvOpSRem, - SpvOpSMod, - SpvOpFRem, - SpvOpFMod, - SpvOpVectorTimesScalar, - SpvOpMatrixTimesScalar, - SpvOpVectorTimesMatrix, - SpvOpMatrixTimesVector, - SpvOpMatrixTimesMatrix, - SpvOpOuterProduct, - SpvOpDot, - SpvOpIAddCarry, - SpvOpISubBorrow, - SpvOpUMulExtended, - SpvOpSMulExtended, - SpvOpAny, - SpvOpAll, - SpvOpIsNan, - SpvOpIsInf, - SpvOpLogicalEqual, - SpvOpLogicalNotEqual, - SpvOpLogicalOr, - SpvOpLogicalAnd, - SpvOpLogicalNot, - SpvOpSelect, - SpvOpIEqual, - SpvOpINotEqual, - SpvOpUGreaterThan, - SpvOpSGreaterThan, - SpvOpUGreaterThanEqual, - SpvOpSGreaterThanEqual, - SpvOpULessThan, - SpvOpSLessThan, - SpvOpULessThanEqual, - SpvOpSLessThanEqual, - SpvOpFOrdEqual, - SpvOpFUnordEqual, - SpvOpFOrdNotEqual, - SpvOpFUnordNotEqual, - SpvOpFOrdLessThan, - SpvOpFUnordLessThan, - SpvOpFOrdGreaterThan, - SpvOpFUnordGreaterThan, - SpvOpFOrdLessThanEqual, - SpvOpFUnordLessThanEqual, - SpvOpFOrdGreaterThanEqual, - SpvOpFUnordGreaterThanEqual, - SpvOpShiftRightLogical, - SpvOpShiftRightArithmetic, - SpvOpShiftLeftLogical, - SpvOpBitwiseOr, - SpvOpBitwiseXor, - SpvOpBitwiseAnd, - SpvOpNot, - SpvOpBitFieldInsert, - SpvOpBitFieldSExtract, - SpvOpBitFieldUExtract, - SpvOpBitReverse, - SpvOpBitCount, - SpvOpPhi, - SpvOpImageSparseSampleImplicitLod, - SpvOpImageSparseSampleExplicitLod, - SpvOpImageSparseSampleDrefImplicitLod, - SpvOpImageSparseSampleDrefExplicitLod, - SpvOpImageSparseSampleProjImplicitLod, - SpvOpImageSparseSampleProjExplicitLod, - SpvOpImageSparseSampleProjDrefImplicitLod, - SpvOpImageSparseSampleProjDrefExplicitLod, - SpvOpImageSparseFetch, - SpvOpImageSparseGather, - SpvOpImageSparseDrefGather, - SpvOpImageSparseTexelsResident, - SpvOpImageSparseRead, - SpvOpSizeOf - // TODO(dneto): Add instructions enabled by ImageQuery - }); + combinator_ops_[0].insert({SpvOpNop, + SpvOpUndef, + SpvOpConstant, + SpvOpConstantTrue, + SpvOpConstantFalse, + SpvOpConstantComposite, + SpvOpConstantSampler, + SpvOpConstantNull, + SpvOpTypeVoid, + SpvOpTypeBool, + SpvOpTypeInt, + SpvOpTypeFloat, + SpvOpTypeVector, + SpvOpTypeMatrix, + SpvOpTypeImage, + SpvOpTypeSampler, + SpvOpTypeSampledImage, + SpvOpTypeArray, + SpvOpTypeRuntimeArray, + SpvOpTypeStruct, + SpvOpTypeOpaque, + SpvOpTypePointer, + SpvOpTypeFunction, + SpvOpTypeEvent, + SpvOpTypeDeviceEvent, + SpvOpTypeReserveId, + SpvOpTypeQueue, + SpvOpTypePipe, + SpvOpTypeForwardPointer, + SpvOpVariable, + SpvOpImageTexelPointer, + SpvOpLoad, + SpvOpAccessChain, + SpvOpInBoundsAccessChain, + SpvOpArrayLength, + SpvOpVectorExtractDynamic, + SpvOpVectorInsertDynamic, + SpvOpVectorShuffle, + SpvOpCompositeConstruct, + SpvOpCompositeExtract, + SpvOpCompositeInsert, + SpvOpCopyObject, + SpvOpTranspose, + SpvOpSampledImage, + SpvOpImageSampleImplicitLod, + SpvOpImageSampleExplicitLod, + SpvOpImageSampleDrefImplicitLod, + SpvOpImageSampleDrefExplicitLod, + SpvOpImageSampleProjImplicitLod, + SpvOpImageSampleProjExplicitLod, + SpvOpImageSampleProjDrefImplicitLod, + SpvOpImageSampleProjDrefExplicitLod, + SpvOpImageFetch, + SpvOpImageGather, + SpvOpImageDrefGather, + SpvOpImageRead, + SpvOpImage, + SpvOpImageQueryFormat, + SpvOpImageQueryOrder, + SpvOpImageQuerySizeLod, + SpvOpImageQuerySize, + SpvOpImageQueryLevels, + SpvOpImageQuerySamples, + SpvOpConvertFToU, + SpvOpConvertFToS, + SpvOpConvertSToF, + SpvOpConvertUToF, + SpvOpUConvert, + SpvOpSConvert, + SpvOpFConvert, + SpvOpQuantizeToF16, + SpvOpBitcast, + SpvOpSNegate, + SpvOpFNegate, + SpvOpIAdd, + SpvOpFAdd, + SpvOpISub, + SpvOpFSub, + SpvOpIMul, + SpvOpFMul, + SpvOpUDiv, + SpvOpSDiv, + SpvOpFDiv, + SpvOpUMod, + SpvOpSRem, + SpvOpSMod, + SpvOpFRem, + SpvOpFMod, + SpvOpVectorTimesScalar, + SpvOpMatrixTimesScalar, + SpvOpVectorTimesMatrix, + SpvOpMatrixTimesVector, + SpvOpMatrixTimesMatrix, + SpvOpOuterProduct, + SpvOpDot, + SpvOpIAddCarry, + SpvOpISubBorrow, + SpvOpUMulExtended, + SpvOpSMulExtended, + SpvOpAny, + SpvOpAll, + SpvOpIsNan, + SpvOpIsInf, + SpvOpLogicalEqual, + SpvOpLogicalNotEqual, + SpvOpLogicalOr, + SpvOpLogicalAnd, + SpvOpLogicalNot, + SpvOpSelect, + SpvOpIEqual, + SpvOpINotEqual, + SpvOpUGreaterThan, + SpvOpSGreaterThan, + SpvOpUGreaterThanEqual, + SpvOpSGreaterThanEqual, + SpvOpULessThan, + SpvOpSLessThan, + SpvOpULessThanEqual, + SpvOpSLessThanEqual, + SpvOpFOrdEqual, + SpvOpFUnordEqual, + SpvOpFOrdNotEqual, + SpvOpFUnordNotEqual, + SpvOpFOrdLessThan, + SpvOpFUnordLessThan, + SpvOpFOrdGreaterThan, + SpvOpFUnordGreaterThan, + SpvOpFOrdLessThanEqual, + SpvOpFUnordLessThanEqual, + SpvOpFOrdGreaterThanEqual, + SpvOpFUnordGreaterThanEqual, + SpvOpShiftRightLogical, + SpvOpShiftRightArithmetic, + SpvOpShiftLeftLogical, + SpvOpBitwiseOr, + SpvOpBitwiseXor, + SpvOpBitwiseAnd, + SpvOpNot, + SpvOpBitFieldInsert, + SpvOpBitFieldSExtract, + SpvOpBitFieldUExtract, + SpvOpBitReverse, + SpvOpBitCount, + SpvOpPhi, + SpvOpImageSparseSampleImplicitLod, + SpvOpImageSparseSampleExplicitLod, + SpvOpImageSparseSampleDrefImplicitLod, + SpvOpImageSparseSampleDrefExplicitLod, + SpvOpImageSparseSampleProjImplicitLod, + SpvOpImageSparseSampleProjExplicitLod, + SpvOpImageSparseSampleProjDrefImplicitLod, + SpvOpImageSparseSampleProjDrefExplicitLod, + SpvOpImageSparseFetch, + SpvOpImageSparseGather, + SpvOpImageSparseDrefGather, + SpvOpImageSparseTexelsResident, + SpvOpImageSparseRead, + SpvOpSizeOf}); } } -void IRContext::AddCombinatorsForExtension(ir::Instruction* extension) { +void IRContext::AddCombinatorsForExtension(Instruction* extension) { assert(extension->opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); const char* extension_name = @@ -545,15 +558,16 @@ void IRContext::RemoveFromIdToName(const Instruction* inst) { } } -ir::LoopDescriptor* IRContext::GetLoopDescriptor(const ir::Function* f) { +LoopDescriptor* IRContext::GetLoopDescriptor(const Function* f) { if (!AreAnalysesValid(kAnalysisLoopAnalysis)) { ResetLoopAnalysis(); } - std::unordered_map::iterator it = + std::unordered_map::iterator it = loop_descriptors_.find(f); if (it == loop_descriptors_.end()) { - return &loop_descriptors_.emplace(std::make_pair(f, ir::LoopDescriptor(f))) + return &loop_descriptors_ + .emplace(std::make_pair(f, LoopDescriptor(this, f))) .first->second; } @@ -561,40 +575,38 @@ ir::LoopDescriptor* IRContext::GetLoopDescriptor(const ir::Function* f) { } // Gets the dominator analysis for function |f|. -opt::DominatorAnalysis* IRContext::GetDominatorAnalysis(const ir::Function* f, - const ir::CFG& in_cfg) { +DominatorAnalysis* IRContext::GetDominatorAnalysis(const Function* f) { if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) { ResetDominatorAnalysis(); } if (dominator_trees_.find(f) == dominator_trees_.end()) { - dominator_trees_[f].InitializeTree(f, in_cfg); + dominator_trees_[f].InitializeTree(*cfg(), f); } return &dominator_trees_[f]; } // Gets the postdominator analysis for function |f|. -opt::PostDominatorAnalysis* IRContext::GetPostDominatorAnalysis( - const ir::Function* f, const ir::CFG& in_cfg) { +PostDominatorAnalysis* IRContext::GetPostDominatorAnalysis(const Function* f) { if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) { ResetDominatorAnalysis(); } if (post_dominator_trees_.find(f) == post_dominator_trees_.end()) { - post_dominator_trees_[f].InitializeTree(f, in_cfg); + post_dominator_trees_[f].InitializeTree(*cfg(), f); } return &post_dominator_trees_[f]; } -bool ir::IRContext::CheckCFG() { +bool IRContext::CheckCFG() { std::unordered_map> real_preds; if (!AreAnalysesValid(kAnalysisCFG)) { return true; } - for (ir::Function& function : *module()) { + for (Function& function : *module()) { for (const auto& bb : function) { bb.ForEachSuccessorLabel([&bb, &real_preds](const uint32_t lab_id) { real_preds[lab_id].push_back(bb.id()); @@ -639,5 +651,5 @@ bool ir::IRContext::CheckCFG() { return true; } -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/ir_context.h b/3rdparty/spirv-tools/source/opt/ir_context.h index bb44b4997..a9d892fa2 100644 --- a/3rdparty/spirv-tools/source/opt/ir_context.h +++ b/3rdparty/spirv-tools/source/opt/ir_context.h @@ -12,28 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_IR_CONTEXT_H -#define SPIRV_TOOLS_IR_CONTEXT_H - -#include "assembly_grammar.h" -#include "cfg.h" -#include "constants.h" -#include "decoration_manager.h" -#include "def_use_manager.h" -#include "dominator_analysis.h" -#include "feature_manager.h" -#include "loop_descriptor.h" -#include "module.h" -#include "scalar_analysis.h" -#include "type_manager.h" +#ifndef SOURCE_OPT_IR_CONTEXT_H_ +#define SOURCE_OPT_IR_CONTEXT_H_ #include #include #include +#include +#include +#include #include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/opt/cfg.h" +#include "source/opt/constants.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/feature_manager.h" +#include "source/opt/fold.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/module.h" +#include "source/opt/register_pressure.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/type_manager.h" +#include "source/opt/value_number_table.h" +#include "source/util/make_unique.h" namespace spvtools { -namespace ir { +namespace opt { class IRContext { public: @@ -60,7 +69,9 @@ class IRContext { kAnalysisLoopAnalysis = 1 << 6, kAnalysisNameMap = 1 << 7, kAnalysisScalarEvolution = 1 << 8, - kAnalysisEnd = 1 << 9 + kAnalysisRegisterPressure = 1 << 9, + kAnalysisValueNumberTable = 1 << 10, + kAnalysisEnd = 1 << 11 }; friend inline Analysis operator|(Analysis lhs, Analysis rhs); @@ -69,7 +80,7 @@ class IRContext { friend inline Analysis& operator<<=(Analysis& a, int shift); // Creates an |IRContext| that contains an owned |Module| - IRContext(spv_target_env env, spvtools::MessageConsumer c) + IRContext(spv_target_env env, MessageConsumer c) : syntax_context_(spvContextCreate(env)), grammar_(syntax_context_), unique_id_(0), @@ -80,12 +91,11 @@ class IRContext { constant_mgr_(nullptr), type_mgr_(nullptr), id_to_name_(nullptr) { - libspirv::SetContextMessageConsumer(syntax_context_, consumer_); + SetContextMessageConsumer(syntax_context_, consumer_); module_->SetContext(this); } - IRContext(spv_target_env env, std::unique_ptr&& m, - spvtools::MessageConsumer c) + IRContext(spv_target_env env, std::unique_ptr&& m, MessageConsumer c) : syntax_context_(spvContextCreate(env)), grammar_(syntax_context_), unique_id_(0), @@ -95,7 +105,7 @@ class IRContext { valid_analyses_(kAnalysisNone), type_mgr_(nullptr), id_to_name_(nullptr) { - libspirv::SetContextMessageConsumer(syntax_context_, consumer_); + SetContextMessageConsumer(syntax_context_, consumer_); module_->SetContext(this); InitializeCombinators(); } @@ -122,8 +132,8 @@ class IRContext { inline IteratorRange capabilities() const; // Iterators for types, constants and global variables instructions. - inline ir::Module::inst_iterator types_values_begin(); - inline ir::Module::inst_iterator types_values_end(); + inline Module::inst_iterator types_values_begin(); + inline Module::inst_iterator types_values_end(); inline IteratorRange types_values(); inline IteratorRange types_values() const; @@ -199,16 +209,34 @@ class IRContext { // Returns a pointer to a def-use manager. If the def-use manager is // invalid, it is rebuilt first. - opt::analysis::DefUseManager* get_def_use_mgr() { + analysis::DefUseManager* get_def_use_mgr() { if (!AreAnalysesValid(kAnalysisDefUse)) { BuildDefUseManager(); } return def_use_mgr_.get(); } + // Returns a pointer to a value number table. If the liveness analysis is + // invalid, it is rebuilt first. + ValueNumberTable* GetValueNumberTable() { + if (!AreAnalysesValid(kAnalysisValueNumberTable)) { + BuildValueNumberTable(); + } + return vn_table_.get(); + } + + // Returns a pointer to a liveness analysis. If the liveness analysis is + // invalid, it is rebuilt first. + LivenessAnalysis* GetLivenessAnalysis() { + if (!AreAnalysesValid(kAnalysisRegisterPressure)) { + BuildRegPressureAnalysis(); + } + return reg_pressure_.get(); + } + // Returns the basic block for instruction |instr|. Re-builds the instruction // block map, if needed. - ir::BasicBlock* get_instr_block(ir::Instruction* instr) { + BasicBlock* get_instr_block(Instruction* instr) { if (!AreAnalysesValid(kAnalysisInstrToBlockMapping)) { BuildInstrToBlockMapping(); } @@ -220,14 +248,14 @@ class IRContext { // needed. // // |id| must be a registered definition. - ir::BasicBlock* get_instr_block(uint32_t id) { - ir::Instruction* def = get_def_use_mgr()->GetDef(id); + BasicBlock* get_instr_block(uint32_t id) { + Instruction* def = get_def_use_mgr()->GetDef(id); return get_instr_block(def); } // Sets the basic block for |inst|. Re-builds the mapping if it has become // invalid. - void set_instr_block(ir::Instruction* inst, ir::BasicBlock* block) { + void set_instr_block(Instruction* inst, BasicBlock* block) { if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) { instr_to_block_[inst] = block; } @@ -235,34 +263,34 @@ class IRContext { // Returns a pointer the decoration manager. If the decoration manger is // invalid, it is rebuilt first. - opt::analysis::DecorationManager* get_decoration_mgr() { + analysis::DecorationManager* get_decoration_mgr() { if (!AreAnalysesValid(kAnalysisDecorations)) { BuildDecorationManager(); } return decoration_mgr_.get(); - }; + } // Returns a pointer to the constant manager. If no constant manager has been // created yet, it creates one. NOTE: Once created, the constant manager // remains active and it is never re-built. - opt::analysis::ConstantManager* get_constant_mgr() { + analysis::ConstantManager* get_constant_mgr() { if (!constant_mgr_) - constant_mgr_.reset(new opt::analysis::ConstantManager(this)); + constant_mgr_ = MakeUnique(this); return constant_mgr_.get(); } // Returns a pointer to the type manager. If no type manager has been created // yet, it creates one. NOTE: Once created, the type manager remains active it // is never re-built. - opt::analysis::TypeManager* get_type_mgr() { + analysis::TypeManager* get_type_mgr() { if (!type_mgr_) - type_mgr_.reset(new opt::analysis::TypeManager(consumer(), this)); + type_mgr_ = MakeUnique(consumer(), this); return type_mgr_.get(); } // Returns a pointer to the scalar evolution analysis. If it is invalid it // will be rebuilt first. - opt::ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() { + ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() { if (!AreAnalysesValid(kAnalysisScalarEvolution)) { BuildScalarEvolutionAnalysis(); } @@ -280,12 +308,10 @@ class IRContext { // Sets the message consumer to the given |consumer|. |consumer| which will be // invoked every time there is a message to be communicated to the outside. - void SetMessageConsumer(spvtools::MessageConsumer c) { - consumer_ = std::move(c); - } + void SetMessageConsumer(MessageConsumer c) { consumer_ = std::move(c); } // Returns the reference to the message consumer for this pass. - const spvtools::MessageConsumer& consumer() const { return consumer_; } + const MessageConsumer& consumer() const { return consumer_; } // Rebuilds the analyses in |set| that are invalid. void BuildInvalidAnalyses(Analysis set); @@ -317,7 +343,7 @@ class IRContext { // // Returns a pointer to the instruction after |inst| or |nullptr| if no such // instruction exists. - Instruction* KillInst(ir::Instruction* inst); + Instruction* KillInst(Instruction* inst); // Returns true if all of the given analyses are valid. bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; } @@ -351,7 +377,7 @@ class IRContext { void KillNamesAndDecorates(uint32_t id); // Kill all name and decorate ops targeting the result id of |inst|. - void KillNamesAndDecorates(ir::Instruction* inst); + void KillNamesAndDecorates(Instruction* inst); // Returns the next unique id for use by an instruction. inline uint32_t TakeNextUniqueId() { @@ -363,7 +389,7 @@ class IRContext { // Returns true if |inst| is a combinator in the current context. // |combinator_ops_| is built if it has not been already. - inline bool IsCombinatorInstruction(ir::Instruction* inst) { + inline bool IsCombinatorInstruction(const Instruction* inst) { if (!AreAnalysesValid(kAnalysisCombinators)) { InitializeCombinators(); } @@ -380,7 +406,7 @@ class IRContext { } // Returns a pointer to the CFG for all the functions in |module_|. - ir::CFG* cfg() { + CFG* cfg() { if (!AreAnalysesValid(kAnalysisCFG)) { BuildCFG(); } @@ -388,30 +414,28 @@ class IRContext { } // Gets the loop descriptor for function |f|. - ir::LoopDescriptor* GetLoopDescriptor(const ir::Function* f); + LoopDescriptor* GetLoopDescriptor(const Function* f); // Gets the dominator analysis for function |f|. - opt::DominatorAnalysis* GetDominatorAnalysis(const ir::Function* f, - const ir::CFG&); + DominatorAnalysis* GetDominatorAnalysis(const Function* f); // Gets the postdominator analysis for function |f|. - opt::PostDominatorAnalysis* GetPostDominatorAnalysis(const ir::Function* f, - const ir::CFG&); + PostDominatorAnalysis* GetPostDominatorAnalysis(const Function* f); // Remove the dominator tree of |f| from the cache. - inline void RemoveDominatorAnalysis(const ir::Function* f) { + inline void RemoveDominatorAnalysis(const Function* f) { dominator_trees_.erase(f); } // Remove the postdominator tree of |f| from the cache. - inline void RemovePostDominatorAnalysis(const ir::Function* f) { + inline void RemovePostDominatorAnalysis(const Function* f) { post_dominator_trees_.erase(f); } // Return the next available SSA id and increment it. inline uint32_t TakeNextId() { return module()->TakeNextIdBound(); } - opt::FeatureManager* get_feature_mgr() { + FeatureManager* get_feature_mgr() { if (!feature_mgr_.get()) { AnalyzeFeatures(); } @@ -419,16 +443,23 @@ class IRContext { } // Returns the grammar for this context. - const libspirv::AssemblyGrammar& grammar() const { return grammar_; } + const AssemblyGrammar& grammar() const { return grammar_; } // If |inst| has not yet been analysed by the def-use manager, then analyse // its definitions and uses. inline void UpdateDefUse(Instruction* inst); + const InstructionFolder& get_instruction_folder() { + if (!inst_folder_) { + inst_folder_ = MakeUnique(this); + } + return *inst_folder_; + } + private: // Builds the def-use manager from scratch, even if it was already valid. void BuildDefUseManager() { - def_use_mgr_.reset(new opt::analysis::DefUseManager(module())); + def_use_mgr_ = MakeUnique(module()); valid_analyses_ = valid_analyses_ | kAnalysisDefUse; } @@ -437,7 +468,7 @@ class IRContext { instr_to_block_.clear(); for (auto& fn : *module_) { for (auto& block : fn) { - block.ForEachInst([this, &block](ir::Instruction* inst) { + block.ForEachInst([this, &block](Instruction* inst) { instr_to_block_[inst] = █ }); } @@ -446,20 +477,33 @@ class IRContext { } void BuildDecorationManager() { - decoration_mgr_.reset(new opt::analysis::DecorationManager(module())); + decoration_mgr_ = MakeUnique(module()); valid_analyses_ = valid_analyses_ | kAnalysisDecorations; } void BuildCFG() { - cfg_.reset(new ir::CFG(module())); + cfg_ = MakeUnique(module()); valid_analyses_ = valid_analyses_ | kAnalysisCFG; } void BuildScalarEvolutionAnalysis() { - scalar_evolution_analysis_.reset(new opt::ScalarEvolutionAnalysis(this)); + scalar_evolution_analysis_ = MakeUnique(this); valid_analyses_ = valid_analyses_ | kAnalysisScalarEvolution; } + // Builds the liveness analysis from scratch, even if it was already valid. + void BuildRegPressureAnalysis() { + reg_pressure_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisRegisterPressure; + } + + // Builds the value number table analysis from scratch, even if it was already + // valid. + void BuildValueNumberTable() { + vn_table_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisValueNumberTable; + } + // Removes all computed dominator and post-dominator trees. This will force // the context to rebuild the trees on demand. void ResetDominatorAnalysis() { @@ -478,7 +522,7 @@ class IRContext { // Analyzes the features in the owned module. Builds the manager if required. void AnalyzeFeatures() { - feature_mgr_.reset(new opt::FeatureManager(grammar_)); + feature_mgr_ = MakeUnique(grammar_); feature_mgr_->Analyze(module()); } @@ -490,7 +534,7 @@ class IRContext { void AddCombinatorsForCapability(uint32_t capability); // Add the combinator opcode for the given extension to combinator_ops_. - void AddCombinatorsForExtension(ir::Instruction* extension); + void AddCombinatorsForExtension(Instruction* extension); // Remove |inst| from |id_to_name_| if it is in map. void RemoveFromIdToName(const Instruction* inst); @@ -504,7 +548,7 @@ class IRContext { spv_context syntax_context_; // Auxiliary object for querying SPIR-V grammar facts. - libspirv::AssemblyGrammar grammar_; + AssemblyGrammar grammar_; // An unique identifier for instructions in |module_|. Can be used to order // instructions in a container. @@ -517,21 +561,21 @@ class IRContext { std::unique_ptr module_; // A message consumer for diagnostics. - spvtools::MessageConsumer consumer_; + MessageConsumer consumer_; // The def-use manager for |module_|. - std::unique_ptr def_use_mgr_; + std::unique_ptr def_use_mgr_; // The instruction decoration manager for |module_|. - std::unique_ptr decoration_mgr_; - std::unique_ptr feature_mgr_; + std::unique_ptr decoration_mgr_; + std::unique_ptr feature_mgr_; // A map from instructions the the basic block they belong to. This mapping is // built on-demand when get_instr_block() is called. // // NOTE: Do not traverse this map. Ever. Use the function and basic block // iterators to traverse instructions. - std::unordered_map instr_to_block_; + std::unordered_map instr_to_block_; // A bitset indicating which analyes are currently valid. Analysis valid_analyses_; @@ -541,55 +585,59 @@ class IRContext { std::unordered_map> combinator_ops_; // The CFG for all the functions in |module_|. - std::unique_ptr cfg_; + std::unique_ptr cfg_; // Each function in the module will create its own dominator tree. We cache // the result so it doesn't need to be rebuilt each time. - std::map dominator_trees_; - std::map - post_dominator_trees_; + std::map dominator_trees_; + std::map post_dominator_trees_; // Cache of loop descriptors for each function. - std::unordered_map loop_descriptors_; + std::unordered_map loop_descriptors_; // Constant manager for |module_|. - std::unique_ptr constant_mgr_; + std::unique_ptr constant_mgr_; // Type manager for |module_|. - std::unique_ptr type_mgr_; + std::unique_ptr type_mgr_; // A map from an id to its corresponding OpName and OpMemberName instructions. std::unique_ptr> id_to_name_; // The cache scalar evolution analysis node. - std::unique_ptr scalar_evolution_analysis_; + std::unique_ptr scalar_evolution_analysis_; + + // The liveness analysis |module_|. + std::unique_ptr reg_pressure_; + + std::unique_ptr vn_table_; + + std::unique_ptr inst_folder_; }; -inline ir::IRContext::Analysis operator|(ir::IRContext::Analysis lhs, - ir::IRContext::Analysis rhs) { - return static_cast(static_cast(lhs) | - static_cast(rhs)); +inline IRContext::Analysis operator|(IRContext::Analysis lhs, + IRContext::Analysis rhs) { + return static_cast(static_cast(lhs) | + static_cast(rhs)); } -inline ir::IRContext::Analysis& operator|=(ir::IRContext::Analysis& lhs, - ir::IRContext::Analysis rhs) { - lhs = static_cast(static_cast(lhs) | - static_cast(rhs)); +inline IRContext::Analysis& operator|=(IRContext::Analysis& lhs, + IRContext::Analysis rhs) { + lhs = static_cast(static_cast(lhs) | + static_cast(rhs)); return lhs; } -inline ir::IRContext::Analysis operator<<(ir::IRContext::Analysis a, - int shift) { - return static_cast(static_cast(a) << shift); +inline IRContext::Analysis operator<<(IRContext::Analysis a, int shift) { + return static_cast(static_cast(a) << shift); } -inline ir::IRContext::Analysis& operator<<=(ir::IRContext::Analysis& a, - int shift) { - a = static_cast(static_cast(a) << shift); +inline IRContext::Analysis& operator<<=(IRContext::Analysis& a, int shift) { + a = static_cast(static_cast(a) << shift); return a; } -std::vector spvtools::ir::IRContext::GetConstants() { +std::vector IRContext::GetConstants() { return module()->GetConstants(); } @@ -629,11 +677,11 @@ IteratorRange IRContext::capabilities() const { return ((const Module*)module())->capabilities(); } -ir::Module::inst_iterator IRContext::types_values_begin() { +Module::inst_iterator IRContext::types_values_begin() { return module()->types_values_begin(); } -ir::Module::inst_iterator IRContext::types_values_end() { +Module::inst_iterator IRContext::types_values_end() { return module()->types_values_end(); } @@ -757,7 +805,7 @@ void IRContext::AddAnnotationInst(std::unique_ptr&& a) { void IRContext::AddType(std::unique_ptr&& t) { module()->AddType(std::move(t)); if (AreAnalysesValid(kAnalysisDefUse)) { - get_def_use_mgr()->AnalyzeInstDef(&*(--types_values_end())); + get_def_use_mgr()->AnalyzeInstDefUse(&*(--types_values_end())); } } @@ -785,7 +833,7 @@ void IRContext::UpdateDefUse(Instruction* inst) { } void IRContext::BuildIdToNameMap() { - id_to_name_.reset(new std::multimap()); + id_to_name_ = MakeUnique>(); for (Instruction& debug_inst : debugs2()) { if (debug_inst.opcode() == SpvOpMemberName || debug_inst.opcode() == SpvOpName) { @@ -804,6 +852,7 @@ IRContext::GetNames(uint32_t id) { return make_range(std::move(result.first), std::move(result.second)); } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // SPIRV_TOOLS_IR_CONTEXT_H + +#endif // SOURCE_OPT_IR_CONTEXT_H_ diff --git a/3rdparty/spirv-tools/source/opt/ir_loader.cpp b/3rdparty/spirv-tools/source/opt/ir_loader.cpp index a526d6f9e..46e2bee42 100644 --- a/3rdparty/spirv-tools/source/opt/ir_loader.cpp +++ b/3rdparty/spirv-tools/source/opt/ir_loader.cpp @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ir_loader.h" +#include "source/opt/ir_loader.h" -#include "log.h" -#include "reflect.h" +#include + +#include "source/opt/log.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" namespace spvtools { -namespace ir { +namespace opt { IrLoader::IrLoader(const MessageConsumer& consumer, Module* m) : consumer_(consumer), @@ -48,7 +51,7 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { Error(consumer_, src, loc, "function inside function"); return false; } - function_.reset(new Function(std::move(spv_inst))); + function_ = MakeUnique(std::move(spv_inst)); } else if (opcode == SpvOpFunctionEnd) { if (function_ == nullptr) { Error(consumer_, src, loc, @@ -71,7 +74,7 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { Error(consumer_, src, loc, "OpLabel inside basic block"); return false; } - block_.reset(new BasicBlock(std::move(spv_inst))); + block_ = MakeUnique(std::move(spv_inst)); } else if (IsTerminatorInst(opcode)) { if (function_ == nullptr) { Error(consumer_, src, loc, "terminator instruction outside function"); @@ -153,9 +156,8 @@ void IrLoader::EndModule() { } for (auto& function : *module_) { for (auto& bb : function) bb.SetParent(&function); - function.SetParent(module_); } } -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/ir_loader.h b/3rdparty/spirv-tools/source/opt/ir_loader.h index 2f0ca8b0b..940d7b0db 100644 --- a/3rdparty/spirv-tools/source/opt/ir_loader.h +++ b/3rdparty/spirv-tools/source/opt/ir_loader.h @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_IR_LOADER_H_ -#define LIBSPIRV_OPT_IR_LOADER_H_ +#ifndef SOURCE_OPT_IR_LOADER_H_ +#define SOURCE_OPT_IR_LOADER_H_ #include +#include +#include -#include "basic_block.h" -#include "instruction.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/module.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { -namespace ir { +namespace opt { // Loader class for constructing SPIR-V in-memory IR representation. Methods in // this class are designed to work with the interface for spvBinaryParse() in @@ -78,7 +80,7 @@ class IrLoader { std::vector dbg_line_info_; }; -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_IR_LOADER_H_ +#endif // SOURCE_OPT_IR_LOADER_H_ diff --git a/3rdparty/spirv-tools/source/opt/iterator.h b/3rdparty/spirv-tools/source/opt/iterator.h index d43dfbef7..444d457c5 100644 --- a/3rdparty/spirv-tools/source/opt/iterator.h +++ b/3rdparty/spirv-tools/source/opt/iterator.h @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_ITERATOR_H_ -#define LIBSPIRV_OPT_ITERATOR_H_ +#ifndef SOURCE_OPT_ITERATOR_H_ +#define SOURCE_OPT_ITERATOR_H_ #include // for ptrdiff_t #include #include #include +#include #include namespace spvtools { -namespace ir { +namespace opt { // An ad hoc iterator class for std::vector>. The // purpose of this iterator class is to provide transparent access to those @@ -166,6 +167,105 @@ inline IteratorRange make_const_range( IteratorType(&container, container.cend())}; } +// Wrapping iterator class that only consider elements that satisfy the given +// predicate |Predicate|. When moving to the next element of the iterator, the +// FilterIterator will iterate over the range until it finds an element that +// satisfies |Predicate| or reaches the end of the iterator. +// +// Currently this iterator is always an input iterator. +template +class FilterIterator + : public std::iterator< + std::input_iterator_tag, typename SubIterator::value_type, + typename SubIterator::difference_type, typename SubIterator::pointer, + typename SubIterator::reference> { + public: + // Iterator interface. + using iterator_category = typename SubIterator::iterator_category; + using value_type = typename SubIterator::value_type; + using pointer = typename SubIterator::pointer; + using reference = typename SubIterator::reference; + using difference_type = typename SubIterator::difference_type; + + using Range = IteratorRange; + + FilterIterator(const IteratorRange& iteration_range, + Predicate predicate) + : cur_(iteration_range.begin()), + end_(iteration_range.end()), + predicate_(predicate) { + if (!IsPredicateSatisfied()) { + MoveToNextPosition(); + } + } + + FilterIterator(const SubIterator& end, Predicate predicate) + : FilterIterator({end, end}, predicate) {} + + inline FilterIterator& operator++() { + MoveToNextPosition(); + return *this; + } + inline FilterIterator operator++(int) { + FilterIterator old = *this; + MoveToNextPosition(); + return old; + } + + reference operator*() const { return *cur_; } + pointer operator->() { return &*cur_; } + + inline bool operator==(const FilterIterator& rhs) const { + return cur_ == rhs.cur_ && end_ == rhs.end_; + } + inline bool operator!=(const FilterIterator& rhs) const { + return !(*this == rhs); + } + + // Returns the underlying iterator. + SubIterator Get() const { return cur_; } + + // Returns the sentinel iterator. + FilterIterator GetEnd() const { return FilterIterator(end_, predicate_); } + + private: + // Returns true if the predicate is satisfied or the current iterator reached + // the end. + bool IsPredicateSatisfied() { return cur_ == end_ || predicate_(*cur_); } + + void MoveToNextPosition() { + if (cur_ == end_) return; + + do { + ++cur_; + } while (!IsPredicateSatisfied()); + } + + SubIterator cur_; + SubIterator end_; + Predicate predicate_; +}; + +template +FilterIterator MakeFilterIterator( + const IteratorRange& sub_iterator_range, Predicate predicate) { + return FilterIterator(sub_iterator_range, predicate); +} + +template +FilterIterator MakeFilterIterator( + const SubIterator& begin, const SubIterator& end, Predicate predicate) { + return MakeFilterIterator(make_range(begin, end), predicate); +} + +template +typename FilterIterator::Range MakeFilterIteratorRange( + const SubIterator& begin, const SubIterator& end, Predicate predicate) { + return typename FilterIterator::Range( + MakeFilterIterator(begin, end, predicate), + MakeFilterIterator(end, end, predicate)); +} + template inline UptrVectorIterator& UptrVectorIterator::operator++() { ++iterator_; @@ -252,7 +352,7 @@ inline return UptrVectorIterator(container_, container_->begin() + index); } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_ITERATOR_H_ +#endif // SOURCE_OPT_ITERATOR_H_ diff --git a/3rdparty/spirv-tools/source/opt/licm_pass.cpp b/3rdparty/spirv-tools/source/opt/licm_pass.cpp index 7faa21d82..d8256679e 100644 --- a/3rdparty/spirv-tools/source/opt/licm_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/licm_pass.cpp @@ -12,44 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/licm_pass.h" -#include "opt/module.h" -#include "opt/pass.h" +#include "source/opt/licm_pass.h" #include #include +#include "source/opt/module.h" +#include "source/opt/pass.h" + namespace spvtools { namespace opt { -Pass::Status LICMPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - bool modified = false; - - if (c != nullptr) { - modified = ProcessIRContext(); - } - - return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +Pass::Status LICMPass::Process() { + return ProcessIRContext() ? Status::SuccessWithChange + : Status::SuccessWithoutChange; } bool LICMPass::ProcessIRContext() { bool modified = false; - ir::Module* module = get_module(); + Module* module = get_module(); // Process each function in the module - for (ir::Function& f : *module) { + for (Function& f : *module) { modified |= ProcessFunction(&f); } return modified; } -bool LICMPass::ProcessFunction(ir::Function* f) { +bool LICMPass::ProcessFunction(Function* f) { bool modified = false; - ir::LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); + LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); // Process each loop in the function - for (ir::Loop& loop : *loop_descriptor) { + for (Loop& loop : *loop_descriptor) { // Ignore nested loops, as we will process them in order in ProcessLoop if (loop.IsNested()) { continue; @@ -59,19 +54,19 @@ bool LICMPass::ProcessFunction(ir::Function* f) { return modified; } -bool LICMPass::ProcessLoop(ir::Loop* loop, ir::Function* f) { +bool LICMPass::ProcessLoop(Loop* loop, Function* f) { bool modified = false; // Process all nested loops first - for (ir::Loop* nested_loop : *loop) { + for (Loop* nested_loop : *loop) { modified |= ProcessLoop(nested_loop, f); } - std::vector loop_bbs{}; + std::vector loop_bbs{}; modified |= AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs); for (size_t i = 0; i < loop_bbs.size(); ++i) { - ir::BasicBlock* bb = loop_bbs[i]; + BasicBlock* bb = loop_bbs[i]; // do not delete the element modified |= AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs); } @@ -79,12 +74,11 @@ bool LICMPass::ProcessLoop(ir::Loop* loop, ir::Function* f) { return modified; } -bool LICMPass::AnalyseAndHoistFromBB(ir::Loop* loop, ir::Function* f, - ir::BasicBlock* bb, - std::vector* loop_bbs) { +bool LICMPass::AnalyseAndHoistFromBB(Loop* loop, Function* f, BasicBlock* bb, + std::vector* loop_bbs) { bool modified = false; - std::function hoist_inst = - [this, &loop, &modified](ir::Instruction* inst) { + std::function hoist_inst = + [this, &loop, &modified](Instruction* inst) { if (loop->ShouldHoistInstruction(this->context(), inst)) { HoistInstruction(loop, inst); modified = true; @@ -95,12 +89,10 @@ bool LICMPass::AnalyseAndHoistFromBB(ir::Loop* loop, ir::Function* f, bb->ForEachInst(hoist_inst, false); } - opt::DominatorAnalysis* dom_analysis = - context()->GetDominatorAnalysis(f, *cfg()); - opt::DominatorTree& dom_tree = dom_analysis->GetDomTree(); + DominatorAnalysis* dom_analysis = context()->GetDominatorAnalysis(f); + DominatorTree& dom_tree = dom_analysis->GetDomTree(); - for (opt::DominatorTreeNode* child_dom_tree_node : - *dom_tree.GetTreeNode(bb)) { + for (DominatorTreeNode* child_dom_tree_node : *dom_tree.GetTreeNode(bb)) { if (loop->IsInsideLoop(child_dom_tree_node->bb_)) { loop_bbs->push_back(child_dom_tree_node->bb_); } @@ -109,14 +101,14 @@ bool LICMPass::AnalyseAndHoistFromBB(ir::Loop* loop, ir::Function* f, return modified; } -bool LICMPass::IsImmediatelyContainedInLoop(ir::Loop* loop, ir::Function* f, - ir::BasicBlock* bb) { - ir::LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); +bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f, + BasicBlock* bb) { + LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); return loop == (*loop_descriptor)[bb->id()]; } -void LICMPass::HoistInstruction(ir::Loop* loop, ir::Instruction* inst) { - ir::BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock(); +void LICMPass::HoistInstruction(Loop* loop, Instruction* inst) { + BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock(); inst->InsertBefore(std::move(&(*pre_header_bb->tail()))); context()->set_instr_block(inst, pre_header_bb); } diff --git a/3rdparty/spirv-tools/source/opt/licm_pass.h b/3rdparty/spirv-tools/source/opt/licm_pass.h index 1d8ae2039..a17450043 100644 --- a/3rdparty/spirv-tools/source/opt/licm_pass.h +++ b/3rdparty/spirv-tools/source/opt/licm_pass.h @@ -15,12 +15,13 @@ #ifndef SOURCE_OPT_LICM_PASS_H_ #define SOURCE_OPT_LICM_PASS_H_ -#include "opt/basic_block.h" -#include "opt/instruction.h" -#include "opt/loop_descriptor.h" -#include "opt/pass.h" - #include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -30,7 +31,7 @@ class LICMPass : public Pass { LICMPass() {} const char* name() const override { return "loop-invariant-code-motion"; } - Status Process(ir::IRContext*) override; + Status Process() override; private: // Searches the IRContext for functions and processes each, moving invariants @@ -40,26 +41,24 @@ class LICMPass : public Pass { // Checks the function for loops, calling ProcessLoop on each one found. // Returns true if a change was made to the function, false otherwise. - bool ProcessFunction(ir::Function* f); + bool ProcessFunction(Function* f); // Checks for invariants in the loop and attempts to move them to the loops // preheader. Works from inner loop to outer when nested loops are found. // Returns true if a change was made to the loop, false otherwise. - bool ProcessLoop(ir::Loop* loop, ir::Function* f); + bool ProcessLoop(Loop* loop, Function* f); // Analyses each instruction in |bb|, hoisting invariants to |pre_header_bb|. // Each child of |bb| wrt to |dom_tree| is pushed to |loop_bbs| - bool AnalyseAndHoistFromBB(ir::Loop* loop, ir::Function* f, - ir::BasicBlock* bb, - std::vector* loop_bbs); + bool AnalyseAndHoistFromBB(Loop* loop, Function* f, BasicBlock* bb, + std::vector* loop_bbs); // Returns true if |bb| is immediately contained in |loop| - bool IsImmediatelyContainedInLoop(ir::Loop* loop, ir::Function* f, - ir::BasicBlock* bb); + bool IsImmediatelyContainedInLoop(Loop* loop, Function* f, BasicBlock* bb); // Move the instruction to the given BasicBlock // This method will update the instruction to block mapping for the context - void HoistInstruction(ir::Loop* loop, ir::Instruction* inst); + void HoistInstruction(Loop* loop, Instruction* inst); }; } // namespace opt diff --git a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp index f87478ab5..5d00e98f7 100644 --- a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp @@ -14,8 +14,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "local_access_chain_convert_pass.h" +#include "source/opt/local_access_chain_convert_pass.h" +#include "ir_builder.h" #include "ir_context.h" #include "iterator.h" @@ -33,20 +34,20 @@ const uint32_t kTypeIntWidthInIdx = 0; void LocalAccessChainConvertPass::BuildAndAppendInst( SpvOp opcode, uint32_t typeId, uint32_t resultId, - const std::vector& in_opnds, - std::vector>* newInsts) { - std::unique_ptr newInst( - new ir::Instruction(context(), opcode, typeId, resultId, in_opnds)); + const std::vector& in_opnds, + std::vector>* newInsts) { + std::unique_ptr newInst( + new Instruction(context(), opcode, typeId, resultId, in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newInst); newInsts->emplace_back(std::move(newInst)); } uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad( - const ir::Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, - std::vector>* newInsts) { + const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, + std::vector>* newInsts) { const uint32_t ldResultId = TakeNextId(); *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx); - const ir::Instruction* varInst = get_def_use_mgr()->GetDef(*varId); + const Instruction* varInst = get_def_use_mgr()->GetDef(*varId); assert(varInst->opcode() == SpvOpVariable); *varPteTypeId = GetPointeeTypeId(varInst); BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId, @@ -56,11 +57,11 @@ uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad( } void LocalAccessChainConvertPass::AppendConstantOperands( - const ir::Instruction* ptrInst, std::vector* in_opnds) { + const Instruction* ptrInst, std::vector* in_opnds) { uint32_t iidIdx = 0; ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) { if (iidIdx > 0) { - const ir::Instruction* cInst = get_def_use_mgr()->GetDef(*iid); + const Instruction* cInst = get_def_use_mgr()->GetDef(*iid); uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx); in_opnds->push_back( {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}}); @@ -69,44 +70,56 @@ void LocalAccessChainConvertPass::AppendConstantOperands( }); } -uint32_t LocalAccessChainConvertPass::GenAccessChainLoadReplacement( - const ir::Instruction* ptrInst, - std::vector>* newInsts) { +void LocalAccessChainConvertPass::ReplaceAccessChainLoad( + const Instruction* address_inst, Instruction* original_load) { // Build and append load of variable in ptrInst + std::vector> new_inst; uint32_t varId; uint32_t varPteTypeId; const uint32_t ldResultId = - BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts); + BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst); + context()->get_decoration_mgr()->CloneDecorations( + original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision}); + original_load->InsertBefore(std::move(new_inst)); - // Build and append Extract - const uint32_t extResultId = TakeNextId(); - const uint32_t ptrPteTypeId = GetPointeeTypeId(ptrInst); - std::vector ext_in_opnds = { - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}}; - AppendConstantOperands(ptrInst, &ext_in_opnds); - BuildAndAppendInst(SpvOpCompositeExtract, ptrPteTypeId, extResultId, - ext_in_opnds, newInsts); - return extResultId; + // Rewrite |original_load| into an extract. + Instruction::OperandList new_operands; + + // copy the result id and the type id to the new operand list. + new_operands.emplace_back(original_load->GetOperand(0)); + new_operands.emplace_back(original_load->GetOperand(1)); + + new_operands.emplace_back( + Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}})); + AppendConstantOperands(address_inst, &new_operands); + original_load->SetOpcode(SpvOpCompositeExtract); + original_load->ReplaceOperands(new_operands); + context()->UpdateDefUse(original_load); } void LocalAccessChainConvertPass::GenAccessChainStoreReplacement( - const ir::Instruction* ptrInst, uint32_t valId, - std::vector>* newInsts) { + const Instruction* ptrInst, uint32_t valId, + std::vector>* newInsts) { // Build and append load of variable in ptrInst uint32_t varId; uint32_t varPteTypeId; const uint32_t ldResultId = BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts); + context()->get_decoration_mgr()->CloneDecorations( + varId, ldResultId, {SpvDecorationRelaxedPrecision}); // Build and append Insert const uint32_t insResultId = TakeNextId(); - std::vector ins_in_opnds = { + std::vector ins_in_opnds = { {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}}; AppendConstantOperands(ptrInst, &ins_in_opnds); BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId, ins_in_opnds, newInsts); + context()->get_decoration_mgr()->CloneDecorations( + varId, insResultId, {SpvDecorationRelaxedPrecision}); + // Build and append Store BuildAndAppendInst(SpvOpStore, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}, @@ -115,11 +128,11 @@ void LocalAccessChainConvertPass::GenAccessChainStoreReplacement( } bool LocalAccessChainConvertPass::IsConstantIndexAccessChain( - const ir::Instruction* acp) const { + const Instruction* acp) const { uint32_t inIdx = 0; return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) { if (inIdx > 0) { - ir::Instruction* opInst = get_def_use_mgr()->GetDef(*tid); + Instruction* opInst = get_def_use_mgr()->GetDef(*tid); if (opInst->opcode() != SpvOpConstant) return false; } ++inIdx; @@ -129,7 +142,7 @@ bool LocalAccessChainConvertPass::IsConstantIndexAccessChain( bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) { if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true; - if (get_def_use_mgr()->WhileEachUser(ptrId, [this](ir::Instruction* user) { + if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) { SpvOp op = user->opcode(); if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { if (!HasOnlySupportedRefs(user->result_id())) { @@ -147,14 +160,14 @@ bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) { return false; } -void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) { +void LocalAccessChainConvertPass::FindTargetVars(Function* func) { for (auto bi = func->begin(); bi != func->end(); ++bi) { for (auto ii = bi->begin(); ii != bi->end(); ++ii) { switch (ii->opcode()) { case SpvOpStore: case SpvOpLoad: { uint32_t varId; - ir::Instruction* ptrInst = GetPtr(&*ii, &varId); + Instruction* ptrInst = GetPtr(&*ii, &varId); if (!IsTargetVar(varId)) break; const SpvOp op = ptrInst->opcode(); // Rule out variables with non-supported refs eg function calls @@ -185,36 +198,30 @@ void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) { } } -bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) { +bool LocalAccessChainConvertPass::ConvertLocalAccessChains(Function* func) { FindTargetVars(func); // Replace access chains of all targeted variables with equivalent // extract and insert sequences bool modified = false; for (auto bi = func->begin(); bi != func->end(); ++bi) { - std::vector dead_instructions; + std::vector dead_instructions; for (auto ii = bi->begin(); ii != bi->end(); ++ii) { switch (ii->opcode()) { case SpvOpLoad: { uint32_t varId; - ir::Instruction* ptrInst = GetPtr(&*ii, &varId); + Instruction* ptrInst = GetPtr(&*ii, &varId); if (!IsNonPtrAccessChain(ptrInst->opcode())) break; if (!IsTargetVar(varId)) break; - std::vector> newInsts; - uint32_t replId = GenAccessChainLoadReplacement(ptrInst, &newInsts); - context()->KillNamesAndDecorates(&*ii); - context()->ReplaceAllUsesWith(ii->result_id(), replId); - dead_instructions.push_back(&*ii); - ++ii; - ii = ii.InsertBefore(std::move(newInsts)); - ++ii; + std::vector> newInsts; + ReplaceAccessChainLoad(ptrInst, &*ii); modified = true; } break; case SpvOpStore: { uint32_t varId; - ir::Instruction* ptrInst = GetPtr(&*ii, &varId); + Instruction* ptrInst = GetPtr(&*ii, &varId); if (!IsNonPtrAccessChain(ptrInst->opcode())) break; if (!IsTargetVar(varId)) break; - std::vector> newInsts; + std::vector> newInsts; uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx); GenAccessChainStoreReplacement(ptrInst, valId, &newInsts); dead_instructions.push_back(&*ii); @@ -230,9 +237,9 @@ bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) { } while (!dead_instructions.empty()) { - ir::Instruction* inst = dead_instructions.back(); + Instruction* inst = dead_instructions.back(); dead_instructions.pop_back(); - DCEInst(inst, [&dead_instructions](ir::Instruction* other_inst) { + DCEInst(inst, [&dead_instructions](Instruction* other_inst) { auto i = std::find(dead_instructions.begin(), dead_instructions.end(), other_inst); if (i != dead_instructions.end()) { @@ -244,9 +251,7 @@ bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) { return modified; } -void LocalAccessChainConvertPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); - +void LocalAccessChainConvertPass::Initialize() { // Initialize Target Variable Caches seen_target_vars_.clear(); seen_non_target_vars_.clear(); @@ -272,7 +277,7 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const { Pass::Status LocalAccessChainConvertPass::ProcessImpl() { // If non-32-bit integer type in module, terminate processing // TODO(): Handle non-32-bit integer constants in access chains - for (const ir::Instruction& inst : get_module()->types_values()) + for (const Instruction& inst : get_module()->types_values()) if (inst.opcode() == SpvOpTypeInt && inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) return Status::SuccessWithoutChange; @@ -284,7 +289,7 @@ Pass::Status LocalAccessChainConvertPass::ProcessImpl() { // Do not process if any disallowed extensions are enabled if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; // Process all entry point functions. - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return ConvertLocalAccessChains(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); @@ -293,8 +298,8 @@ Pass::Status LocalAccessChainConvertPass::ProcessImpl() { LocalAccessChainConvertPass::LocalAccessChainConvertPass() {} -Pass::Status LocalAccessChainConvertPass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status LocalAccessChainConvertPass::Process() { + Initialize(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h index 98f009a89..9d06890bf 100644 --- a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h +++ b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h @@ -14,20 +14,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ -#define LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ +#ifndef SOURCE_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ +#define SOURCE_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ #include #include +#include #include +#include #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -36,14 +39,15 @@ namespace opt { class LocalAccessChainConvertPass : public MemPass { public: LocalAccessChainConvertPass(); - const char* name() const override { return "convert-local-access-chains"; } - Status Process(ir::IRContext* c) override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + const char* name() const override { return "convert-local-access-chains"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse; } - using ProcessFunction = std::function; + using ProcessFunction = std::function; private: // Return true if all refs through |ptrId| are only loads or stores and @@ -55,42 +59,42 @@ class LocalAccessChainConvertPass : public MemPass { // Search |func| and cache function scope variables of target type that are // not accessed with non-constant-index access chains. Also cache non-target // variables. - void FindTargetVars(ir::Function* func); + void FindTargetVars(Function* func); // Build instruction from |opcode|, |typeId|, |resultId|, and |in_opnds|. // Append to |newInsts|. - void BuildAndAppendInst( - SpvOp opcode, uint32_t typeId, uint32_t resultId, - const std::vector& in_opnds, - std::vector>* newInsts); + void BuildAndAppendInst(SpvOp opcode, uint32_t typeId, uint32_t resultId, + const std::vector& in_opnds, + std::vector>* newInsts); // Build load of variable in |ptrInst| and append to |newInsts|. // Return var in |varId| and its pointee type in |varPteTypeId|. uint32_t BuildAndAppendVarLoad( - const ir::Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, - std::vector>* newInsts); + const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, + std::vector>* newInsts); // Append literal integer operands to |in_opnds| corresponding to constant // integer operands from access chain |ptrInst|. Assumes all indices in // access chains are OpConstant. - void AppendConstantOperands(const ir::Instruction* ptrInst, - std::vector* in_opnds); + void AppendConstantOperands(const Instruction* ptrInst, + std::vector* in_opnds); // Create a load/insert/store equivalent to a store of // |valId| through (constant index) access chaing |ptrInst|. // Append to |newInsts|. void GenAccessChainStoreReplacement( - const ir::Instruction* ptrInst, uint32_t valId, - std::vector>* newInsts); + const Instruction* ptrInst, uint32_t valId, + std::vector>* newInsts); - // For the (constant index) access chain |ptrInst|, create an - // equivalent load and extract. Append to |newInsts|. - uint32_t GenAccessChainLoadReplacement( - const ir::Instruction* ptrInst, - std::vector>* newInsts); + // For the (constant index) access chain |address_inst|, create an + // equivalent load and extract that replaces |original_load|. The result id + // of the extract will be the same as the original result id of + // |original_load|. + void ReplaceAccessChainLoad(const Instruction* address_inst, + Instruction* original_load); // Return true if all indices of access chain |acp| are OpConstant integers - bool IsConstantIndexAccessChain(const ir::Instruction* acp) const; + bool IsConstantIndexAccessChain(const Instruction* acp) const; // Identify all function scope variables of target type which are // accessed only with loads, stores and access chains with constant @@ -101,7 +105,7 @@ class LocalAccessChainConvertPass : public MemPass { // // Nested access chains and pointer access chains are not currently // converted. - bool ConvertLocalAccessChains(ir::Function* func); + bool ConvertLocalAccessChains(Function* func); // Initialize extensions whitelist void InitExtensions(); @@ -109,7 +113,7 @@ class LocalAccessChainConvertPass : public MemPass { // Return true if all extensions in this module are allowed by this pass. bool AllExtensionsSupported() const; - void Initialize(ir::IRContext* c); + void Initialize(); Pass::Status ProcessImpl(); // Variables with only supported references, ie. loads and stores using @@ -123,4 +127,4 @@ class LocalAccessChainConvertPass : public MemPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ +#endif // SOURCE_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.cpp b/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.cpp index d6fb48caf..9539e6556 100644 --- a/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.cpp +++ b/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.cpp @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "local_redundancy_elimination.h" +#include "source/opt/local_redundancy_elimination.h" -#include "value_number_table.h" +#include "source/opt/value_number_table.h" namespace spvtools { namespace opt { -Pass::Status LocalRedundancyEliminationPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status LocalRedundancyEliminationPass::Process() { bool modified = false; ValueNumberTable vnTable(context()); @@ -39,11 +37,11 @@ Pass::Status LocalRedundancyEliminationPass::Process(ir::IRContext* c) { } bool LocalRedundancyEliminationPass::EliminateRedundanciesInBB( - ir::BasicBlock* block, const ValueNumberTable& vnTable, + BasicBlock* block, const ValueNumberTable& vnTable, std::map* value_to_ids) { bool modified = false; - auto func = [this, &vnTable, &modified, value_to_ids](ir::Instruction* inst) { + auto func = [this, &vnTable, &modified, value_to_ids](Instruction* inst) { if (inst->result_id() == 0) { return; } diff --git a/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.h b/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.h index cc83b6061..9f55c8bfe 100644 --- a/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.h +++ b/3rdparty/spirv-tools/source/opt/local_redundancy_elimination.h @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ -#define LIBSPIRV_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ +#ifndef SOURCE_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ +#define SOURCE_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ -#include "ir_context.h" -#include "pass.h" -#include "value_number_table.h" +#include + +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" +#include "source/opt/value_number_table.h" namespace spvtools { namespace opt { @@ -32,14 +34,14 @@ namespace opt { class LocalRedundancyEliminationPass : public Pass { public: const char* name() const override { return "local-redundancy-elimination"; } - Status Process(ir::IRContext*) override; - virtual ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisDecorations | - ir::IRContext::kAnalysisCombinators | ir::IRContext::kAnalysisCFG | - ir::IRContext::kAnalysisDominatorAnalysis | - ir::IRContext::kAnalysisNameMap; + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; } protected: @@ -54,7 +56,7 @@ class LocalRedundancyEliminationPass : public Pass { // dominates |bb|. // // Returns true if the module is changed. - bool EliminateRedundanciesInBB(ir::BasicBlock* block, + bool EliminateRedundanciesInBB(BasicBlock* block, const ValueNumberTable& vnTable, std::map* value_to_ids); }; @@ -62,4 +64,4 @@ class LocalRedundancyEliminationPass : public Pass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ +#endif // SOURCE_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ diff --git a/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp index ae4c4977c..bb909f4aa 100644 --- a/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.cpp @@ -14,13 +14,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "local_single_block_elim_pass.h" +#include "source/opt/local_single_block_elim_pass.h" -#include "iterator.h" +#include + +#include "source/opt/iterator.h" namespace spvtools { namespace opt { - namespace { const uint32_t kStoreValIdInIdx = 1; @@ -29,7 +30,7 @@ const uint32_t kStoreValIdInIdx = 1; bool LocalSingleBlockLoadStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) { if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true; - if (get_def_use_mgr()->WhileEachUser(ptrId, [this](ir::Instruction* user) { + if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) { SpvOp op = user->opcode(); if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { if (!HasOnlySupportedRefs(user->result_id())) { @@ -48,13 +49,15 @@ bool LocalSingleBlockLoadStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) { } bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim( - ir::Function* func) { - // Perform local store/load and load/load elimination on each block + Function* func) { + // Perform local store/load, load/load and store/store elimination + // on each block bool modified = false; + std::vector instructions_to_kill; + std::unordered_set instructions_to_save; for (auto bi = func->begin(); bi != func->end(); ++bi) { var2store_.clear(); var2load_.clear(); - pinned_vars_.clear(); auto next = bi->begin(); for (auto ii = next; ii != bi->end(); ii = next) { ++next; @@ -62,34 +65,56 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim( case SpvOpStore: { // Verify store variable is target type uint32_t varId; - ir::Instruction* ptrInst = GetPtr(&*ii, &varId); + Instruction* ptrInst = GetPtr(&*ii, &varId); if (!IsTargetVar(varId)) continue; if (!HasOnlySupportedRefs(varId)) continue; - // Register the store + // If a store to the whole variable, remember it for succeeding + // loads and stores. Otherwise forget any previous store to that + // variable. if (ptrInst->opcode() == SpvOpVariable) { - // if not pinned, look for WAW - if (pinned_vars_.find(varId) == pinned_vars_.end()) { - auto si = var2store_.find(varId); - if (si != var2store_.end()) { + // If a previous store to same variable, mark the store + // for deletion if not still used. + auto prev_store = var2store_.find(varId); + if (prev_store != var2store_.end() && + instructions_to_save.count(prev_store->second) == 0) { + instructions_to_kill.push_back(prev_store->second); + modified = true; + } + + bool kill_store = false; + auto li = var2load_.find(varId); + if (li != var2load_.end()) { + if (ii->GetSingleWordInOperand(kStoreValIdInIdx) == + li->second->result_id()) { + // We are storing the same value that already exists in the + // memory location. The store does nothing. + kill_store = true; } } - var2store_[varId] = &*ii; + + if (!kill_store) { + var2store_[varId] = &*ii; + var2load_.erase(varId); + } else { + instructions_to_kill.push_back(&*ii); + modified = true; + } } else { assert(IsNonPtrAccessChain(ptrInst->opcode())); var2store_.erase(varId); + var2load_.erase(varId); } - pinned_vars_.erase(varId); - var2load_.erase(varId); } break; case SpvOpLoad: { // Verify store variable is target type uint32_t varId; - ir::Instruction* ptrInst = GetPtr(&*ii, &varId); + Instruction* ptrInst = GetPtr(&*ii, &varId); if (!IsTargetVar(varId)) continue; if (!HasOnlySupportedRefs(varId)) continue; - // Look for previous store or load uint32_t replId = 0; if (ptrInst->opcode() == SpvOpVariable) { + // If a load from a variable, look for a previous store or + // load from that variable and use its value. auto si = var2store_.find(varId); if (si != var2store_.end()) { replId = si->second->GetSingleWordInOperand(kStoreValIdInIdx); @@ -99,16 +124,21 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim( replId = li->second->result_id(); } } + } else { + // If a partial load of a previously seen store, remember + // not to delete the store. + auto si = var2store_.find(varId); + if (si != var2store_.end()) instructions_to_save.insert(si->second); } if (replId != 0) { // replace load's result id and delete load context()->KillNamesAndDecorates(&*ii); context()->ReplaceAllUsesWith(ii->result_id(), replId); + instructions_to_kill.push_back(&*ii); modified = true; } else { if (ptrInst->opcode() == SpvOpVariable) var2load_[varId] = &*ii; // register load - pinned_vars_.insert(varId); } } break; case SpvOpFunctionCall: { @@ -116,19 +146,21 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim( // TODO(): Handle more optimally var2store_.clear(); var2load_.clear(); - pinned_vars_.clear(); } break; default: break; } } } + + for (Instruction* inst : instructions_to_kill) { + context()->KillInst(inst); + } + return modified; } -void LocalSingleBlockLoadStoreElimPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); - +void LocalSingleBlockLoadStoreElimPass::Initialize() { // Initialize Target Type Caches seen_target_vars_.clear(); seen_non_target_vars_.clear(); @@ -164,17 +196,19 @@ Pass::Status LocalSingleBlockLoadStoreElimPass::ProcessImpl() { // return unmodified. if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; // Process all entry point functions - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return LocalSingleBlockLoadStoreElim(fp); }; + bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElimPass() {} +LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElimPass() = + default; -Pass::Status LocalSingleBlockLoadStoreElimPass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status LocalSingleBlockLoadStoreElimPass::Process() { + Initialize(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.h b/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.h index fa68788c1..3dead9834 100644 --- a/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/local_single_block_elim_pass.h @@ -14,20 +14,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ -#define LIBSPIRV_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ +#ifndef SOURCE_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ +#define SOURCE_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ #include #include #include +#include #include #include #include -#include "basic_block.h" -#include "def_use_manager.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { @@ -36,11 +37,12 @@ namespace opt { class LocalSingleBlockLoadStoreElimPass : public MemPass { public: LocalSingleBlockLoadStoreElimPass(); - const char* name() const override { return "eliminate-local-single-block"; } - Status Process(ir::IRContext* c) override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + const char* name() const override { return "eliminate-local-single-block"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; } private: @@ -57,7 +59,7 @@ class LocalSingleBlockLoadStoreElimPass : public MemPass { // load id with previous id and delete load. Finally, check if // remaining stores are useless, and delete store and variable // where possible. Assumes logical addressing. - bool LocalSingleBlockLoadStoreElim(ir::Function* func); + bool LocalSingleBlockLoadStoreElim(Function* func); // Initialize extensions whitelist void InitExtensions(); @@ -65,7 +67,7 @@ class LocalSingleBlockLoadStoreElimPass : public MemPass { // Return true if all extensions in this module are supported by this pass. bool AllExtensionsSupported() const; - void Initialize(ir::IRContext* c); + void Initialize(); Pass::Status ProcessImpl(); // Map from function scope variable to a store of that variable in the @@ -73,14 +75,14 @@ class LocalSingleBlockLoadStoreElimPass : public MemPass { // at the start of each block and incrementally updated as the block // is scanned. The stores are candidates for elimination. The map is // conservatively cleared when a function call is encountered. - std::unordered_map var2store_; + std::unordered_map var2store_; // Map from function scope variable to a load of that variable in the // current block whose value is currently valid. This map is cleared // at the start of each block and incrementally updated as the block // is scanned. The stores are candidates for elimination. The map is // conservatively cleared when a function call is encountered. - std::unordered_map var2load_; + std::unordered_map var2load_; // Set of variables whose most recent store in the current block cannot be // deleted, for example, if there is a load of the variable which is @@ -100,4 +102,4 @@ class LocalSingleBlockLoadStoreElimPass : public MemPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ +#endif // SOURCE_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp index 405fba3c0..4c837fc73 100644 --- a/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.cpp @@ -14,11 +14,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "local_single_store_elim_pass.h" +#include "source/opt/local_single_store_elim_pass.h" -#include "cfa.h" -#include "iterator.h" -#include "latest_version_glsl_std_450_header.h" +#include "source/cfa.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/iterator.h" namespace spvtools { namespace opt { @@ -30,220 +30,21 @@ const uint32_t kVariableInitIdInIdx = 1; } // anonymous namespace -bool LocalSingleStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) { - if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true; - if (get_def_use_mgr()->WhileEachUser(ptrId, [this](ir::Instruction* user) { - SpvOp op = user->opcode(); - if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { - if (!HasOnlySupportedRefs(user->result_id())) { - return false; - } - } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && - !IsNonTypeDecorate(op)) { - return false; - } - return true; - })) { - supported_ref_ptrs_.insert(ptrId); - return true; - } - return false; -} - -void LocalSingleStoreElimPass::SingleStoreAnalyze(ir::Function* func) { - ssa_var2store_.clear(); - non_ssa_vars_.clear(); - store2idx_.clear(); - store2blk_.clear(); - for (auto bi = func->begin(); bi != func->end(); ++bi) { - uint32_t instIdx = 0; - for (auto ii = bi->begin(); ii != bi->end(); ++ii, ++instIdx) { - uint32_t varId = 0; - ir::Instruction* ptrInst = nullptr; - switch (ii->opcode()) { - case SpvOpStore: { - ptrInst = GetPtr(&*ii, &varId); - } break; - case SpvOpVariable: { - // If initializer, treat like store - if (ii->NumInOperands() > 1) { - varId = ii->result_id(); - ptrInst = &*ii; - } - } break; - default: - break; - } // switch - if (varId == 0) continue; - // Verify variable is target type - if (non_ssa_vars_.find(varId) != non_ssa_vars_.end()) continue; - if (ptrInst->opcode() != SpvOpVariable) { - non_ssa_vars_.insert(varId); - ssa_var2store_.erase(varId); - continue; - } - // Verify target type and function storage class - if (!IsTargetVar(varId)) { - non_ssa_vars_.insert(varId); - continue; - } - if (!HasOnlySupportedRefs(varId)) { - non_ssa_vars_.insert(varId); - continue; - } - // Ignore variables with multiple stores - if (ssa_var2store_.find(varId) != ssa_var2store_.end()) { - non_ssa_vars_.insert(varId); - ssa_var2store_.erase(varId); - continue; - } - // Remember pointer to variable's store and it's - // ordinal position in block - ssa_var2store_[varId] = &*ii; - store2idx_[&*ii] = instIdx; - store2blk_[&*ii] = &*bi; - } - } -} - -LocalSingleStoreElimPass::GetBlocksFunction -LocalSingleStoreElimPass::AugmentedCFGSuccessorsFunction() const { - return [this](const ir::BasicBlock* block) { - auto asmi = augmented_successors_map_.find(block); - if (asmi != augmented_successors_map_.end()) return &(*asmi).second; - auto smi = successors_map_.find(block); - return &(*smi).second; - }; -} - -LocalSingleStoreElimPass::GetBlocksFunction -LocalSingleStoreElimPass::AugmentedCFGPredecessorsFunction() const { - return [this](const ir::BasicBlock* block) { - auto apmi = augmented_predecessors_map_.find(block); - if (apmi != augmented_predecessors_map_.end()) return &(*apmi).second; - auto pmi = predecessors_map_.find(block); - return &(*pmi).second; - }; -} - -void LocalSingleStoreElimPass::CalculateImmediateDominators( - ir::Function* func) { - // Compute CFG - vector ordered_blocks; - predecessors_map_.clear(); - successors_map_.clear(); - for (auto& blk : *func) { - ordered_blocks.push_back(&blk); - const auto& const_blk = blk; - const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) { - successors_map_[&blk].push_back(label2block_[sbid]); - predecessors_map_[label2block_[sbid]].push_back(&blk); - }); - } - // Compute Augmented CFG - augmented_successors_map_.clear(); - augmented_predecessors_map_.clear(); - successors_map_[cfg()->pseudo_exit_block()] = {}; - predecessors_map_[cfg()->pseudo_entry_block()] = {}; - auto succ_func = [this](const ir::BasicBlock* b) { - return &successors_map_[b]; - }; - auto pred_func = [this](const ir::BasicBlock* b) { - return &predecessors_map_[b]; - }; - CFA::ComputeAugmentedCFG( - ordered_blocks, cfg()->pseudo_entry_block(), cfg()->pseudo_exit_block(), - &augmented_successors_map_, &augmented_predecessors_map_, succ_func, - pred_func); - // Compute Dominators - vector postorder; - auto ignore_block = [](cbb_ptr) {}; - auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; - spvtools::CFA::DepthFirstTraversal( - ordered_blocks[0], AugmentedCFGSuccessorsFunction(), ignore_block, - [&](cbb_ptr b) { postorder.push_back(b); }, ignore_edge); - auto edges = spvtools::CFA::CalculateDominators( - postorder, AugmentedCFGPredecessorsFunction()); - idom_.clear(); - for (auto edge : edges) idom_[edge.first] = edge.second; -} - -bool LocalSingleStoreElimPass::Dominates(ir::BasicBlock* blk0, uint32_t idx0, - ir::BasicBlock* blk1, uint32_t idx1) { - if (blk0 == blk1) return idx0 <= idx1; - ir::BasicBlock* b = blk1; - while (idom_[b] != b) { - b = idom_[b]; - if (b == blk0) return true; - } - return false; -} - -bool LocalSingleStoreElimPass::SingleStoreProcess(ir::Function* func) { - CalculateImmediateDominators(func); +bool LocalSingleStoreElimPass::LocalSingleStoreElim(Function* func) { bool modified = false; - for (auto bi = func->begin(); bi != func->end(); ++bi) { - uint32_t instIdx = 0; - for (auto ii = bi->begin(); ii != bi->end(); ++ii, ++instIdx) { - if (ii->opcode() != SpvOpLoad) continue; - uint32_t varId; - ir::Instruction* ptrInst = GetPtr(&*ii, &varId); - // Skip access chain loads - if (ptrInst->opcode() != SpvOpVariable) continue; - const auto vsi = ssa_var2store_.find(varId); - if (vsi == ssa_var2store_.end()) continue; - if (non_ssa_vars_.find(varId) != non_ssa_vars_.end()) continue; - // store must dominate load - if (!Dominates(store2blk_[vsi->second], store2idx_[vsi->second], &*bi, - instIdx)) - continue; - // Determine replacement id depending on OpStore or OpVariable - uint32_t replId; - if (vsi->second->opcode() == SpvOpStore) - replId = vsi->second->GetSingleWordInOperand(kStoreValIdInIdx); - else - replId = vsi->second->GetSingleWordInOperand(kVariableInitIdInIdx); - // Replace all instances of the load's id with the SSA value's id - // and add load to removal list - context()->KillNamesAndDecorates(&*ii); - context()->ReplaceAllUsesWith(ii->result_id(), replId); - modified = true; + + // Check all function scope variables in |func|. + BasicBlock* entry_block = &*func->begin(); + for (Instruction& inst : *entry_block) { + if (inst.opcode() != SpvOpVariable) { + break; } + + modified |= ProcessVariable(&inst); } return modified; } -bool LocalSingleStoreElimPass::LocalSingleStoreElim(ir::Function* func) { - bool modified = false; - SingleStoreAnalyze(func); - if (ssa_var2store_.empty()) return false; - modified |= SingleStoreProcess(func); - return modified; -} - -void LocalSingleStoreElimPass::Initialize(ir::IRContext* irContext) { - InitializeProcessing(irContext); - - // Initialize function and block maps - label2block_.clear(); - for (auto& fn : *get_module()) { - for (auto& blk : fn) { - uint32_t bid = blk.id(); - label2block_[bid] = &blk; - } - } - - // Initialize Target Type Caches - seen_target_vars_.clear(); - seen_non_target_vars_.clear(); - - // Initialize Supported Ref Pointer Cache - supported_ref_ptrs_.clear(); - - // Initialize extension whitelist - InitExtensions(); -} - bool LocalSingleStoreElimPass::AllExtensionsSupported() const { // If any extension not in whitelist, return false for (auto& ei : get_module()->extensions()) { @@ -259,30 +60,25 @@ Pass::Status LocalSingleStoreElimPass::ProcessImpl() { // Assumes relaxed logical addressing only (see instruction.h) if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) return Status::SuccessWithoutChange; - // Do not process if module contains OpGroupDecorate. Additional - // support required in KillNamesAndDecorates(). - // TODO(greg-lunarg): Add support for OpGroupDecorate - for (auto& ai : get_module()->annotations()) - if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; + // Do not process if any disallowed extensions are enabled if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; // Process all entry point functions - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return LocalSingleStoreElim(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -LocalSingleStoreElimPass::LocalSingleStoreElimPass() {} +LocalSingleStoreElimPass::LocalSingleStoreElimPass() = default; -Pass::Status LocalSingleStoreElimPass::Process(ir::IRContext* irContext) { - Initialize(irContext); +Pass::Status LocalSingleStoreElimPass::Process() { + InitExtensionWhiteList(); return ProcessImpl(); } -void LocalSingleStoreElimPass::InitExtensions() { - extensions_whitelist_.clear(); +void LocalSingleStoreElimPass::InitExtensionWhiteList() { extensions_whitelist_.insert({ "SPV_AMD_shader_explicit_vertex_parameter", "SPV_AMD_shader_trinary_minmax", @@ -319,6 +115,127 @@ void LocalSingleStoreElimPass::InitExtensions() { "SPV_EXT_descriptor_indexing", }); } +bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) { + std::vector users; + FindUses(var_inst, &users); + + Instruction* store_inst = FindSingleStoreAndCheckUses(var_inst, users); + + if (store_inst == nullptr) { + return false; + } + + return RewriteLoads(store_inst, users); +} + +Instruction* LocalSingleStoreElimPass::FindSingleStoreAndCheckUses( + Instruction* var_inst, const std::vector& users) const { + // Make sure there is exactly 1 store. + Instruction* store_inst = nullptr; + + // If |var_inst| has an initializer, then that will count as a store. + if (var_inst->NumInOperands() > 1) { + store_inst = var_inst; + } + + for (Instruction* user : users) { + switch (user->opcode()) { + case SpvOpStore: + // Since we are in the relaxed addressing mode, the use has to be the + // base address of the store, and not the value being store. Otherwise, + // we would have a pointer to a pointer to function scope memory, which + // is not allowed. + if (store_inst == nullptr) { + store_inst = user; + } else { + // More than 1 store. + return nullptr; + } + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + if (FeedsAStore(user)) { + // Has a partial store. Cannot propagate that. + return nullptr; + } + break; + case SpvOpLoad: + case SpvOpImageTexelPointer: + case SpvOpName: + case SpvOpCopyObject: + break; + default: + if (!user->IsDecoration()) { + // Don't know if this instruction modifies the variable. + // Conservatively assume it is a store. + return nullptr; + } + break; + } + } + return store_inst; +} + +void LocalSingleStoreElimPass::FindUses( + const Instruction* var_inst, std::vector* users) const { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + def_use_mgr->ForEachUser(var_inst, [users, this](Instruction* user) { + users->push_back(user); + if (user->opcode() == SpvOpCopyObject) { + FindUses(user, users); + } + }); +} + +bool LocalSingleStoreElimPass::FeedsAStore(Instruction* inst) const { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + return !def_use_mgr->WhileEachUser(inst, [this](Instruction* user) { + switch (user->opcode()) { + case SpvOpStore: + return false; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpCopyObject: + return !FeedsAStore(user); + case SpvOpLoad: + case SpvOpImageTexelPointer: + case SpvOpName: + return true; + default: + // Don't know if this instruction modifies the variable. + // Conservatively assume it is a store. + return user->IsDecoration(); + } + }); +} + +bool LocalSingleStoreElimPass::RewriteLoads( + Instruction* store_inst, const std::vector& uses) { + BasicBlock* store_block = context()->get_instr_block(store_inst); + DominatorAnalysis* dominator_analysis = + context()->GetDominatorAnalysis(store_block->GetParent()); + + uint32_t stored_id; + if (store_inst->opcode() == SpvOpStore) + stored_id = store_inst->GetSingleWordInOperand(kStoreValIdInIdx); + else + stored_id = store_inst->GetSingleWordInOperand(kVariableInitIdInIdx); + + std::vector uses_in_store_block; + bool modified = false; + for (Instruction* use : uses) { + if (use->opcode() == SpvOpLoad) { + if (dominator_analysis->Dominates(store_inst, use)) { + modified = true; + context()->KillNamesAndDecorates(use->result_id()); + context()->ReplaceAllUsesWith(use->result_id(), stored_id); + context()->KillInst(use); + } + } + } + + return modified; +} } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.h b/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.h index 47b80049d..d3d64b829 100644 --- a/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/local_single_store_elim_pass.h @@ -14,130 +14,82 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ -#define LIBSPIRV_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ +#ifndef SOURCE_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ +#define SOURCE_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ #include #include #include +#include #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { // See optimizer.hpp for documentation. -class LocalSingleStoreElimPass : public MemPass { - using cbb_ptr = const ir::BasicBlock*; +class LocalSingleStoreElimPass : public Pass { + using cbb_ptr = const BasicBlock*; public: LocalSingleStoreElimPass(); - const char* name() const override { return "eliminate-local-single-store"; } - Status Process(ir::IRContext* irContext) override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse; + const char* name() const override { return "eliminate-local-single-store"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; } private: - // Return true if all refs through |ptrId| are only loads or stores and - // cache ptrId in supported_ref_ptrs_. TODO(dnovillo): This function is - // replicated in other passes and it's slightly different in every pass. Is it - // possible to make one common implementation? - bool HasOnlySupportedRefs(uint32_t ptrId); - - // Find all function scope variables in |func| that are stored to - // only once (SSA) and map to their stored value id. Only analyze - // variables of scalar, vector, matrix types and struct and array - // types comprising only these types. Currently this analysis is - // is not done in the presence of function calls. TODO(): Allow - // analysis in the presence of function calls. - void SingleStoreAnalyze(ir::Function* func); - - using GetBlocksFunction = - std::function*(const ir::BasicBlock*)>; - - /// Returns the block successors function for the augmented CFG. - GetBlocksFunction AugmentedCFGSuccessorsFunction() const; - - /// Returns the block predecessors function for the augmented CFG. - GetBlocksFunction AugmentedCFGPredecessorsFunction() const; - - // Calculate immediate dominators for |func|'s CFG. Leaves result - // in idom_. Entries for augmented CFG (pseudo blocks) are not created. - // TODO(dnovillo): Move to new CFG class. - void CalculateImmediateDominators(ir::Function* func); - - // Return true if instruction in |blk0| at ordinal position |idx0| - // dominates instruction in |blk1| at position |idx1|. - bool Dominates(ir::BasicBlock* blk0, uint32_t idx0, ir::BasicBlock* blk1, - uint32_t idx1); - - // For each load of an SSA variable in |func|, replace all uses of - // the load with the value stored if the store dominates the load. - // Assumes that SingleStoreAnalyze() has just been run. Return true - // if any instructions are modified. - bool SingleStoreProcess(ir::Function* func); - // Do "single-store" optimization of function variables defined only // with a single non-access-chain store in |func|. Replace all their // non-access-chain loads with the value that is stored and eliminate // any resulting dead code. - bool LocalSingleStoreElim(ir::Function* func); + bool LocalSingleStoreElim(Function* func); // Initialize extensions whitelist - void InitExtensions(); + void InitExtensionWhiteList(); // Return true if all extensions in this module are allowed by this pass. bool AllExtensionsSupported() const; - void Initialize(ir::IRContext* irContext); Pass::Status ProcessImpl(); - // Map from block's label id to block - std::unordered_map label2block_; + // If there is a single store to |var_inst|, and it covers the entire + // variable, then replace all of the loads of the entire variable that are + // dominated by the store by the value that was stored. Returns true if the + // module was changed. + bool ProcessVariable(Instruction* var_inst); - // Map from SSA Variable to its single store - std::unordered_map ssa_var2store_; + // Collects all of the uses of |var_inst| into |uses|. This looks through + // OpObjectCopy's that copy the address of the variable, and collects those + // uses as well. + void FindUses(const Instruction* var_inst, + std::vector* uses) const; - // Map from store to its ordinal position in its block. - std::unordered_map store2idx_; + // Returns a store to |var_inst| if + // - it is a store to the entire variable, + // - and there are no other instructions that may modify |var_inst|. + Instruction* FindSingleStoreAndCheckUses( + Instruction* var_inst, const std::vector& users) const; - // Map from store to its block. - std::unordered_map store2blk_; + // Returns true if the address that results from |inst| may be used as a base + // address in a store instruction or may be used to compute the base address + // of a store instruction. + bool FeedsAStore(Instruction* inst) const; - // Set of non-SSA Variables - std::unordered_set non_ssa_vars_; - - // Variables with only supported references, ie. loads and stores using - // variable directly or through non-ptr access chains. - std::unordered_set supported_ref_ptrs_; - - // CFG Predecessors - std::unordered_map> - predecessors_map_; - - // CFG Successors - std::unordered_map> - successors_map_; - - // CFG Augmented Predecessors - std::unordered_map> - augmented_predecessors_map_; - - // CFG Augmented Successors - std::unordered_map> - augmented_successors_map_; - - // Immediate Dominator Map - // If block has no idom it points to itself. - std::unordered_map idom_; + // Replaces all of the loads in |uses| by the value stored in |store_inst|. + // The load instructions are then killed. + bool RewriteLoads(Instruction* store_inst, + const std::vector& uses); // Extensions supported by this pass. std::unordered_set extensions_whitelist_; @@ -146,4 +98,4 @@ class LocalSingleStoreElimPass : public MemPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ +#endif // SOURCE_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.cpp index 14c14bd3b..ec7326ed0 100644 --- a/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.cpp @@ -14,22 +14,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "local_ssa_elim_pass.h" +#include "source/opt/local_ssa_elim_pass.h" -#include "cfa.h" -#include "iterator.h" -#include "ssa_rewrite_pass.h" +#include "source/cfa.h" +#include "source/opt/iterator.h" +#include "source/opt/ssa_rewrite_pass.h" namespace spvtools { namespace opt { -void LocalMultiStoreElimPass::Initialize(ir::IRContext* c) { - InitializeProcessing(c); - - // Initialize extension whitelist - InitExtensions(); -} - bool LocalMultiStoreElimPass::AllExtensionsSupported() const { // If any extension not in whitelist, return false for (auto& ei : get_module()->extensions()) { @@ -54,17 +47,18 @@ Pass::Status LocalMultiStoreElimPass::ProcessImpl() { // Do not process if any disallowed extensions are enabled if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; // Process functions - ProcessFunction pfn = [this](ir::Function* fp) { + ProcessFunction pfn = [this](Function* fp) { return SSARewriter(this).RewriteFunctionIntoSSA(fp); }; bool modified = ProcessEntryPointCallTree(pfn, get_module()); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -LocalMultiStoreElimPass::LocalMultiStoreElimPass() {} +LocalMultiStoreElimPass::LocalMultiStoreElimPass() = default; -Pass::Status LocalMultiStoreElimPass::Process(ir::IRContext* c) { - Initialize(c); +Pass::Status LocalMultiStoreElimPass::Process() { + // Initialize extension whitelist + InitExtensions(); return ProcessImpl(); } diff --git a/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.h b/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.h index c3f70f62c..63d3c33ba 100644 --- a/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.h +++ b/3rdparty/spirv-tools/source/opt/local_ssa_elim_pass.h @@ -14,39 +14,41 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOCAL_SSA_ELIM_PASS_H_ -#define LIBSPIRV_OPT_LOCAL_SSA_ELIM_PASS_H_ +#ifndef SOURCE_OPT_LOCAL_SSA_ELIM_PASS_H_ +#define SOURCE_OPT_LOCAL_SSA_ELIM_PASS_H_ #include #include #include +#include #include #include #include +#include -#include "basic_block.h" -#include "def_use_manager.h" -#include "mem_pass.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { // See optimizer.hpp for documentation. class LocalMultiStoreElimPass : public MemPass { - using cbb_ptr = const ir::BasicBlock*; + using cbb_ptr = const BasicBlock*; public: using GetBlocksFunction = - std::function*(const ir::BasicBlock*)>; + std::function*(const BasicBlock*)>; LocalMultiStoreElimPass(); - const char* name() const override { return "eliminate-local-multi-store"; } - Status Process(ir::IRContext* c) override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping; + const char* name() const override { return "eliminate-local-multi-store"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; } private: @@ -56,7 +58,6 @@ class LocalMultiStoreElimPass : public MemPass { // Return true if all extensions in this module are allowed by this pass. bool AllExtensionsSupported() const; - void Initialize(ir::IRContext* c); Pass::Status ProcessImpl(); // Extensions supported by this pass. @@ -66,4 +67,4 @@ class LocalMultiStoreElimPass : public MemPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOCAL_SSA_ELIM_PASS_H_ +#endif // SOURCE_OPT_LOCAL_SSA_ELIM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/log.h b/3rdparty/spirv-tools/source/opt/log.h index 70ae223c0..f87cbf381 100644 --- a/3rdparty/spirv-tools/source/opt/log.h +++ b/3rdparty/spirv-tools/source/opt/log.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_LOG_H_ -#define SPIRV_TOOLS_LOG_H_ +#ifndef SOURCE_OPT_LOG_H_ +#define SOURCE_OPT_LOG_H_ #include #include @@ -228,4 +228,4 @@ static_assert(PP_NARGS(0, 0, 0, 0, 0) == 5, "PP_NARGS macro error"); static_assert(PP_NARGS(1 + 1, 2, 3 / 3) == 3, "PP_NARGS macro error"); static_assert(PP_NARGS((1, 1), 2, (3, 3)) == 3, "PP_NARGS macro error"); -#endif // SPIRV_TOOLS_LOG_H_ +#endif // SOURCE_OPT_LOG_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_dependence.cpp b/3rdparty/spirv-tools/source/opt/loop_dependence.cpp new file mode 100644 index 000000000..d8de699bf --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_dependence.cpp @@ -0,0 +1,1675 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_dependence.h" + +#include +#include +#include +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/scalar_analysis_nodes.h" + +namespace spvtools { +namespace opt { + +using SubscriptPair = std::pair; + +namespace { + +// Calculate the greatest common divisor of a & b using Stein's algorithm. +// https://en.wikipedia.org/wiki/Binary_GCD_algorithm +int64_t GreatestCommonDivisor(int64_t a, int64_t b) { + // Simple cases + if (a == b) { + return a; + } else if (a == 0) { + return b; + } else if (b == 0) { + return a; + } + + // Both even + if (a % 2 == 0 && b % 2 == 0) { + return 2 * GreatestCommonDivisor(a / 2, b / 2); + } + + // Even a, odd b + if (a % 2 == 0 && b % 2 == 1) { + return GreatestCommonDivisor(a / 2, b); + } + + // Odd a, even b + if (a % 2 == 1 && b % 2 == 0) { + return GreatestCommonDivisor(a, b / 2); + } + + // Both odd, reduce the larger argument + if (a > b) { + return GreatestCommonDivisor((a - b) / 2, b); + } else { + return GreatestCommonDivisor((b - a) / 2, a); + } +} + +// Check if node is affine, ie in the form: a0*i0 + a1*i1 + ... an*in + c +// and contains only the following types of nodes: SERecurrentNode, SEAddNode +// and SEConstantNode +bool IsInCorrectFormForGCDTest(SENode* node) { + bool children_ok = true; + + if (auto add_node = node->AsSEAddNode()) { + for (auto child : add_node->GetChildren()) { + children_ok &= IsInCorrectFormForGCDTest(child); + } + } + + bool this_ok = node->AsSERecurrentNode() || node->AsSEAddNode() || + node->AsSEConstantNode(); + + return children_ok && this_ok; +} + +// If |node| is an SERecurrentNode then returns |node| or if |node| is an +// SEAddNode returns a vector of SERecurrentNode that are its children. +std::vector GetAllTopLevelRecurrences(SENode* node) { + auto nodes = std::vector{}; + if (auto recurrent_node = node->AsSERecurrentNode()) { + nodes.push_back(recurrent_node); + } + + if (auto add_node = node->AsSEAddNode()) { + for (auto child : add_node->GetChildren()) { + auto child_nodes = GetAllTopLevelRecurrences(child); + nodes.insert(nodes.end(), child_nodes.begin(), child_nodes.end()); + } + } + + return nodes; +} + +// If |node| is an SEConstantNode then returns |node| or if |node| is an +// SEAddNode returns a vector of SEConstantNode that are its children. +std::vector GetAllTopLevelConstants(SENode* node) { + auto nodes = std::vector{}; + if (auto recurrent_node = node->AsSEConstantNode()) { + nodes.push_back(recurrent_node); + } + + if (auto add_node = node->AsSEAddNode()) { + for (auto child : add_node->GetChildren()) { + auto child_nodes = GetAllTopLevelConstants(child); + nodes.insert(nodes.end(), child_nodes.begin(), child_nodes.end()); + } + } + + return nodes; +} + +bool AreOffsetsAndCoefficientsConstant( + const std::vector& nodes) { + for (auto node : nodes) { + if (!node->GetOffset()->AsSEConstantNode() || + !node->GetOffset()->AsSEConstantNode()) { + return false; + } + } + return true; +} + +// Fold all SEConstantNode that appear in |recurrences| and |constants| into a +// single integer value. +int64_t CalculateConstantTerm(const std::vector& recurrences, + const std::vector& constants) { + int64_t constant_term = 0; + for (auto recurrence : recurrences) { + constant_term += + recurrence->GetOffset()->AsSEConstantNode()->FoldToSingleValue(); + } + + for (auto constant : constants) { + constant_term += constant->FoldToSingleValue(); + } + + return constant_term; +} + +int64_t CalculateGCDFromCoefficients( + const std::vector& recurrences, int64_t running_gcd) { + for (SERecurrentNode* recurrence : recurrences) { + auto coefficient = recurrence->GetCoefficient()->AsSEConstantNode(); + + running_gcd = GreatestCommonDivisor( + running_gcd, std::abs(coefficient->FoldToSingleValue())); + } + + return running_gcd; +} + +// Compare 2 fractions while first normalizing them, e.g. 2/4 and 4/8 will both +// be simplified to 1/2 and then determined to be equal. +bool NormalizeAndCompareFractions(int64_t numerator_0, int64_t denominator_0, + int64_t numerator_1, int64_t denominator_1) { + auto gcd_0 = + GreatestCommonDivisor(std::abs(numerator_0), std::abs(denominator_0)); + auto gcd_1 = + GreatestCommonDivisor(std::abs(numerator_1), std::abs(denominator_1)); + + auto normalized_numerator_0 = numerator_0 / gcd_0; + auto normalized_denominator_0 = denominator_0 / gcd_0; + auto normalized_numerator_1 = numerator_1 / gcd_1; + auto normalized_denominator_1 = denominator_1 / gcd_1; + + return normalized_numerator_0 == normalized_numerator_1 && + normalized_denominator_0 == normalized_denominator_1; +} + +} // namespace + +bool LoopDependenceAnalysis::GetDependence(const Instruction* source, + const Instruction* destination, + DistanceVector* distance_vector) { + // Start off by finding and marking all the loops in |loops_| that are + // irrelevant to the dependence analysis. + MarkUnsusedDistanceEntriesAsIrrelevant(source, destination, distance_vector); + + Instruction* source_access_chain = GetOperandDefinition(source, 0); + Instruction* destination_access_chain = GetOperandDefinition(destination, 0); + + auto num_access_chains = + (source_access_chain->opcode() == SpvOpAccessChain) + + (destination_access_chain->opcode() == SpvOpAccessChain); + + // If neither is an access chain, then they are load/store to a variable. + if (num_access_chains == 0) { + if (source_access_chain != destination_access_chain) { + // Not the same location, report independence + return true; + } else { + // Accessing the same variable + for (auto& entry : distance_vector->GetEntries()) { + entry = DistanceEntry(); + } + return false; + } + } + + // If only one is an access chain, it could be accessing a part of a struct + if (num_access_chains == 1) { + auto source_is_chain = source_access_chain->opcode() == SpvOpAccessChain; + auto access_chain = + source_is_chain ? source_access_chain : destination_access_chain; + auto variable = + source_is_chain ? destination_access_chain : source_access_chain; + + auto location_in_chain = GetOperandDefinition(access_chain, 0); + + if (variable != location_in_chain) { + // Not the same location, report independence + return true; + } else { + // Accessing the same variable + for (auto& entry : distance_vector->GetEntries()) { + entry = DistanceEntry(); + } + return false; + } + } + + // If the access chains aren't collecting from the same structure there is no + // dependence. + Instruction* source_array = GetOperandDefinition(source_access_chain, 0); + Instruction* destination_array = + GetOperandDefinition(destination_access_chain, 0); + + // Nested access chains are not supported yet, bail out. + if (source_array->opcode() == SpvOpAccessChain || + destination_array->opcode() == SpvOpAccessChain) { + for (auto& entry : distance_vector->GetEntries()) { + entry = DistanceEntry(); + } + return false; + } + + if (source_array != destination_array) { + PrintDebug("Proved independence through different arrays."); + return true; + } + + // To handle multiple subscripts we must get every operand in the access + // chains past the first. + std::vector source_subscripts = GetSubscripts(source); + std::vector destination_subscripts = GetSubscripts(destination); + + auto sets_of_subscripts = + PartitionSubscripts(source_subscripts, destination_subscripts); + + auto first_coupled = std::partition( + std::begin(sets_of_subscripts), std::end(sets_of_subscripts), + [](const std::set>& set) { + return set.size() == 1; + }); + + // Go through each subscript testing for independence. + // If any subscript results in independence, we prove independence between the + // load and store. + // If we can't prove independence we store what information we can gather in + // a DistanceVector. + for (auto it = std::begin(sets_of_subscripts); it < first_coupled; ++it) { + auto source_subscript = std::get<0>(*(*it).begin()); + auto destination_subscript = std::get<1>(*(*it).begin()); + + SENode* source_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(source_subscript)); + SENode* destination_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(destination_subscript)); + + // Check the loops are in a form we support. + auto subscript_pair = std::make_pair(source_node, destination_node); + + const Loop* loop = GetLoopForSubscriptPair(subscript_pair); + if (loop) { + if (!IsSupportedLoop(loop)) { + PrintDebug( + "GetDependence found an unsupported loop form. Assuming <=> for " + "loop."); + DistanceEntry* distance_entry = + GetDistanceEntryForSubscriptPair(subscript_pair, distance_vector); + if (distance_entry) { + distance_entry->direction = DistanceEntry::Directions::ALL; + } + continue; + } + } + + // If either node is simplified to a CanNotCompute we can't perform any + // analysis so must assume <=> dependence and return. + if (source_node->GetType() == SENode::CanNotCompute || + destination_node->GetType() == SENode::CanNotCompute) { + // Record the <=> dependence if we can get a DistanceEntry + PrintDebug( + "GetDependence found source_node || destination_node as " + "CanNotCompute. Abandoning evaluation for this subscript."); + DistanceEntry* distance_entry = + GetDistanceEntryForSubscriptPair(subscript_pair, distance_vector); + if (distance_entry) { + distance_entry->direction = DistanceEntry::Directions::ALL; + } + continue; + } + + // We have no induction variables so can apply a ZIV test. + if (IsZIV(subscript_pair)) { + PrintDebug("Found a ZIV subscript pair"); + if (ZIVTest(subscript_pair)) { + PrintDebug("Proved independence with ZIVTest."); + return true; + } + } + + // We have only one induction variable so should attempt an SIV test. + if (IsSIV(subscript_pair)) { + PrintDebug("Found a SIV subscript pair."); + if (SIVTest(subscript_pair, distance_vector)) { + PrintDebug("Proved independence with SIVTest."); + return true; + } + } + + // We have multiple induction variables so should attempt an MIV test. + if (IsMIV(subscript_pair)) { + PrintDebug("Found a MIV subscript pair."); + if (GCDMIVTest(subscript_pair)) { + PrintDebug("Proved independence with the GCD test."); + auto current_loops = CollectLoops(source_node, destination_node); + + for (auto current_loop : current_loops) { + auto distance_entry = + GetDistanceEntryForLoop(current_loop, distance_vector); + distance_entry->direction = DistanceEntry::Directions::NONE; + } + return true; + } + } + } + + for (auto it = first_coupled; it < std::end(sets_of_subscripts); ++it) { + auto coupled_instructions = *it; + std::vector coupled_subscripts{}; + + for (const auto& elem : coupled_instructions) { + auto source_subscript = std::get<0>(elem); + auto destination_subscript = std::get<1>(elem); + + SENode* source_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(source_subscript)); + SENode* destination_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(destination_subscript)); + + coupled_subscripts.push_back({source_node, destination_node}); + } + + auto supported = true; + + for (const auto& subscript : coupled_subscripts) { + auto loops = CollectLoops(std::get<0>(subscript), std::get<1>(subscript)); + + auto is_subscript_supported = + std::all_of(std::begin(loops), std::end(loops), + [this](const Loop* l) { return IsSupportedLoop(l); }); + + supported = supported && is_subscript_supported; + } + + if (DeltaTest(coupled_subscripts, distance_vector)) { + return true; + } + } + + // We were unable to prove independence so must gather all of the direction + // information we found. + PrintDebug( + "Couldn't prove independence.\n" + "All possible direction information has been collected in the input " + "DistanceVector."); + + return false; +} + +bool LoopDependenceAnalysis::ZIVTest( + const std::pair& subscript_pair) { + auto source = std::get<0>(subscript_pair); + auto destination = std::get<1>(subscript_pair); + + PrintDebug("Performing ZIVTest"); + // If source == destination, dependence with direction = and distance 0. + if (source == destination) { + PrintDebug("ZIVTest found EQ dependence."); + return false; + } else { + PrintDebug("ZIVTest found independence."); + // Otherwise we prove independence. + return true; + } +} + +bool LoopDependenceAnalysis::SIVTest( + const std::pair& subscript_pair, + DistanceVector* distance_vector) { + DistanceEntry* distance_entry = + GetDistanceEntryForSubscriptPair(subscript_pair, distance_vector); + if (!distance_entry) { + PrintDebug( + "SIVTest could not find a DistanceEntry for subscript_pair. Exiting"); + } + + SENode* source_node = std::get<0>(subscript_pair); + SENode* destination_node = std::get<1>(subscript_pair); + + int64_t source_induction_count = CountInductionVariables(source_node); + int64_t destination_induction_count = + CountInductionVariables(destination_node); + + // If the source node has no induction variables we can apply a + // WeakZeroSrcTest. + if (source_induction_count == 0) { + PrintDebug("Found source has no induction variable."); + if (WeakZeroSourceSIVTest( + source_node, destination_node->AsSERecurrentNode(), + destination_node->AsSERecurrentNode()->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with WeakZeroSourceSIVTest."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + + // If the destination has no induction variables we can apply a + // WeakZeroDestTest. + if (destination_induction_count == 0) { + PrintDebug("Found destination has no induction variable."); + if (WeakZeroDestinationSIVTest( + source_node->AsSERecurrentNode(), destination_node, + source_node->AsSERecurrentNode()->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with WeakZeroDestinationSIVTest."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + + // We now need to collect the SERecurrentExpr nodes from source and + // destination. We do not handle cases where source or destination have + // multiple SERecurrentExpr nodes. + std::vector source_recurrent_nodes = + source_node->CollectRecurrentNodes(); + std::vector destination_recurrent_nodes = + destination_node->CollectRecurrentNodes(); + + if (source_recurrent_nodes.size() == 1 && + destination_recurrent_nodes.size() == 1) { + PrintDebug("Found source and destination have 1 induction variable."); + SERecurrentNode* source_recurrent_expr = *source_recurrent_nodes.begin(); + SERecurrentNode* destination_recurrent_expr = + *destination_recurrent_nodes.begin(); + + // If the coefficients are identical we can apply a StrongSIVTest. + if (source_recurrent_expr->GetCoefficient() == + destination_recurrent_expr->GetCoefficient()) { + PrintDebug("Found source and destination share coefficient."); + if (StrongSIVTest(source_node, destination_node, + source_recurrent_expr->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with StrongSIVTest"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + + // If the coefficients are of equal magnitude and opposite sign we can + // apply a WeakCrossingSIVTest. + if (source_recurrent_expr->GetCoefficient() == + scalar_evolution_.CreateNegation( + destination_recurrent_expr->GetCoefficient())) { + PrintDebug("Found source coefficient = -destination coefficient."); + if (WeakCrossingSIVTest(source_node, destination_node, + source_recurrent_expr->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with WeakCrossingSIVTest"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + } + + return false; +} + +bool LoopDependenceAnalysis::StrongSIVTest(SENode* source, SENode* destination, + SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing StrongSIVTest."); + // If both source and destination are SERecurrentNodes we can perform tests + // based on distance. + // If either source or destination contain value unknown nodes or if one or + // both are not SERecurrentNodes we must attempt a symbolic test. + std::vector source_value_unknown_nodes = + source->CollectValueUnknownNodes(); + std::vector destination_value_unknown_nodes = + destination->CollectValueUnknownNodes(); + if (source_value_unknown_nodes.size() > 0 || + destination_value_unknown_nodes.size() > 0) { + PrintDebug( + "StrongSIVTest found symbolics. Will attempt SymbolicStrongSIVTest."); + return SymbolicStrongSIVTest(source, destination, coefficient, + distance_entry); + } + + if (!source->AsSERecurrentNode() || !destination->AsSERecurrentNode()) { + PrintDebug( + "StrongSIVTest could not simplify source and destination to " + "SERecurrentNodes so will exit."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; + } + + // Build an SENode for distance. + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + SENode* source_constant_term = + GetConstantTerm(subscript_loop, source->AsSERecurrentNode()); + SENode* destination_constant_term = + GetConstantTerm(subscript_loop, destination->AsSERecurrentNode()); + if (!source_constant_term || !destination_constant_term) { + PrintDebug( + "StrongSIVTest could not collect the constant terms of either source " + "or destination so will exit."); + return false; + } + SENode* constant_term_delta = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateSubtraction( + destination_constant_term, source_constant_term)); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + // We must check the offset delta and coefficient are constants. + int64_t distance = 0; + SEConstantNode* delta_constant = constant_term_delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + PrintDebug( + "StrongSIVTest found delta value and coefficient value as constants " + "with values:\n" + "\tdelta value: " + + ToString(delta_value) + + "\n\tcoefficient value: " + ToString(coefficient_value) + "\n"); + // Check if the distance is not integral to try to prove independence. + if (delta_value % coefficient_value != 0) { + PrintDebug( + "StrongSIVTest proved independence through distance not being an " + "integer."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / coefficient_value; + PrintDebug("StrongSIV test found distance as " + ToString(distance)); + } + } else { + // If we can't fold delta and coefficient to single values we can't produce + // distance. + // As a result we can't perform the rest of the pass and must assume + // dependence in all directions. + PrintDebug("StrongSIVTest could not produce a distance. Must exit."); + distance_entry->distance = DistanceEntry::Directions::ALL; + return false; + } + + // Next we gather the upper and lower bounds as constants if possible. If + // distance > upper_bound - lower_bound we prove independence. + SENode* lower_bound = GetLowerBound(subscript_loop); + SENode* upper_bound = GetUpperBound(subscript_loop); + if (lower_bound && upper_bound) { + PrintDebug("StrongSIVTest found bounds."); + SENode* bounds = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(upper_bound, lower_bound)); + + if (bounds->GetType() == SENode::SENodeType::Constant) { + int64_t bounds_value = bounds->AsSEConstantNode()->FoldToSingleValue(); + PrintDebug( + "StrongSIVTest found upper_bound - lower_bound as a constant with " + "value " + + ToString(bounds_value)); + + // If the absolute value of the distance is > upper bound - lower bound + // then we prove independence. + if (llabs(distance) > llabs(bounds_value)) { + PrintDebug( + "StrongSIVTest proved independence through distance escaping the " + "loop bounds."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::NONE; + distance_entry->distance = distance; + return true; + } + } + } else { + PrintDebug("StrongSIVTest was unable to gather lower and upper bounds."); + } + + // Otherwise we can get a direction as follows + // { < if distance > 0 + // direction = { = if distance == 0 + // { > if distance < 0 + PrintDebug( + "StrongSIVTest could not prove independence. Gathering direction " + "information."); + if (distance > 0) { + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::LT; + distance_entry->distance = distance; + return false; + } + if (distance == 0) { + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::EQ; + distance_entry->distance = 0; + return false; + } + if (distance < 0) { + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::GT; + distance_entry->distance = distance; + return false; + } + + // We were unable to prove independence or discern any additional information + // Must assume <=> direction. + PrintDebug( + "StrongSIVTest was unable to determine any dependence information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::SymbolicStrongSIVTest( + SENode* source, SENode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing SymbolicStrongSIVTest."); + SENode* source_destination_delta = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(source, destination)); + // By cancelling out the induction variables by subtracting the source and + // destination we can produce an expression of symbolics and constants. This + // expression can be compared to the loop bounds to find if the offset is + // outwith the bounds. + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + if (IsProvablyOutsideOfLoopBounds(subscript_loop, source_destination_delta, + coefficient)) { + PrintDebug( + "SymbolicStrongSIVTest proved independence through loop bounds."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "SymbolicStrongSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::WeakZeroSourceSIVTest( + SENode* source, SERecurrentNode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing WeakZeroSourceSIVTest."); + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + // Build an SENode for distance. + SENode* destination_constant_term = + GetConstantTerm(subscript_loop, destination); + SENode* delta = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(source, destination_constant_term)); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + int64_t distance = 0; + SEConstantNode* delta_constant = delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + PrintDebug( + "WeakZeroSourceSIVTest folding delta and coefficient to constants."); + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + // Check if the distance is not integral. + if (delta_value % coefficient_value != 0) { + PrintDebug( + "WeakZeroSourceSIVTest proved independence through distance not " + "being an integer."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / coefficient_value; + PrintDebug( + "WeakZeroSourceSIVTest calculated distance with the following " + "values\n" + "\tdelta value: " + + ToString(delta_value) + + "\n\tcoefficient value: " + ToString(coefficient_value) + + "\n\tdistance: " + ToString(distance) + "\n"); + } + } else { + PrintDebug( + "WeakZeroSourceSIVTest was unable to fold delta and coefficient to " + "constants."); + } + + // If we can prove the distance is outside the bounds we prove independence. + SEConstantNode* lower_bound = + GetLowerBound(subscript_loop)->AsSEConstantNode(); + SEConstantNode* upper_bound = + GetUpperBound(subscript_loop)->AsSEConstantNode(); + if (lower_bound && upper_bound) { + PrintDebug("WeakZeroSourceSIVTest found bounds as SEConstantNodes."); + int64_t lower_bound_value = lower_bound->FoldToSingleValue(); + int64_t upper_bound_value = upper_bound->FoldToSingleValue(); + if (!IsWithinBounds(llabs(distance), lower_bound_value, + upper_bound_value)) { + PrintDebug( + "WeakZeroSourceSIVTest proved independence through distance escaping " + "the loop bounds."); + PrintDebug( + "Bound values were as follow\n" + "\tlower bound value: " + + ToString(lower_bound_value) + + "\n\tupper bound value: " + ToString(upper_bound_value) + + "\n\tdistance value: " + ToString(distance) + "\n"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::NONE; + distance_entry->distance = distance; + return true; + } + } else { + PrintDebug( + "WeakZeroSourceSIVTest was unable to find lower and upper bound as " + "SEConstantNodes."); + } + + // Now we want to see if we can detect to peel the first or last iterations. + + // We get the FirstTripValue as GetFirstTripInductionNode() + + // GetConstantTerm(destination) + SENode* first_trip_SENode = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + GetFirstTripInductionNode(subscript_loop), + GetConstantTerm(subscript_loop, destination))); + + // If source == FirstTripValue, peel_first. + if (first_trip_SENode) { + PrintDebug("WeakZeroSourceSIVTest built first_trip_SENode."); + if (first_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroSourceSIVTest has found first_trip_SENode as an " + "SEConstantNode with value: " + + ToString(first_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (source == first_trip_SENode) { + // We have found that peeling the first iteration will break dependency. + PrintDebug( + "WeakZeroSourceSIVTest has found peeling first iteration will break " + "dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_first = true; + return false; + } + } else { + PrintDebug("WeakZeroSourceSIVTest was unable to build first_trip_SENode"); + } + + // We get the LastTripValue as GetFinalTripInductionNode(coefficient) + + // GetConstantTerm(destination) + SENode* final_trip_SENode = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + GetFinalTripInductionNode(subscript_loop, coefficient), + GetConstantTerm(subscript_loop, destination))); + + // If source == LastTripValue, peel_last. + if (final_trip_SENode) { + PrintDebug("WeakZeroSourceSIVTest built final_trip_SENode."); + if (first_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroSourceSIVTest has found final_trip_SENode as an " + "SEConstantNode with value: " + + ToString(final_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (source == final_trip_SENode) { + // We have found that peeling the last iteration will break dependency. + PrintDebug( + "WeakZeroSourceSIVTest has found peeling final iteration will break " + "dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_last = true; + return false; + } + } else { + PrintDebug("WeakZeroSourceSIVTest was unable to build final_trip_SENode"); + } + + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "WeakZeroSourceSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::WeakZeroDestinationSIVTest( + SERecurrentNode* source, SENode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing WeakZeroDestinationSIVTest."); + // Build an SENode for distance. + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + SENode* source_constant_term = GetConstantTerm(subscript_loop, source); + SENode* delta = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(destination, source_constant_term)); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + int64_t distance = 0; + SEConstantNode* delta_constant = delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + PrintDebug( + "WeakZeroDestinationSIVTest folding delta and coefficient to " + "constants."); + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + // Check if the distance is not integral. + if (delta_value % coefficient_value != 0) { + PrintDebug( + "WeakZeroDestinationSIVTest proved independence through distance not " + "being an integer."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / coefficient_value; + PrintDebug( + "WeakZeroDestinationSIVTest calculated distance with the following " + "values\n" + "\tdelta value: " + + ToString(delta_value) + + "\n\tcoefficient value: " + ToString(coefficient_value) + + "\n\tdistance: " + ToString(distance) + "\n"); + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to fold delta and coefficient " + "to constants."); + } + + // If we can prove the distance is outside the bounds we prove independence. + SEConstantNode* lower_bound = + GetLowerBound(subscript_loop)->AsSEConstantNode(); + SEConstantNode* upper_bound = + GetUpperBound(subscript_loop)->AsSEConstantNode(); + if (lower_bound && upper_bound) { + PrintDebug("WeakZeroDestinationSIVTest found bounds as SEConstantNodes."); + int64_t lower_bound_value = lower_bound->FoldToSingleValue(); + int64_t upper_bound_value = upper_bound->FoldToSingleValue(); + if (!IsWithinBounds(llabs(distance), lower_bound_value, + upper_bound_value)) { + PrintDebug( + "WeakZeroDestinationSIVTest proved independence through distance " + "escaping the loop bounds."); + PrintDebug( + "Bound values were as follows\n" + "\tlower bound value: " + + ToString(lower_bound_value) + + "\n\tupper bound value: " + ToString(upper_bound_value) + + "\n\tdistance value: " + ToString(distance)); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::NONE; + distance_entry->distance = distance; + return true; + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to find lower and upper bound " + "as SEConstantNodes."); + } + + // Now we want to see if we can detect to peel the first or last iterations. + + // We get the FirstTripValue as GetFirstTripInductionNode() + + // GetConstantTerm(source) + SENode* first_trip_SENode = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateAddNode(GetFirstTripInductionNode(subscript_loop), + GetConstantTerm(subscript_loop, source))); + + // If destination == FirstTripValue, peel_first. + if (first_trip_SENode) { + PrintDebug("WeakZeroDestinationSIVTest built first_trip_SENode."); + if (first_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroDestinationSIVTest has found first_trip_SENode as an " + "SEConstantNode with value: " + + ToString(first_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (destination == first_trip_SENode) { + // We have found that peeling the first iteration will break dependency. + PrintDebug( + "WeakZeroDestinationSIVTest has found peeling first iteration will " + "break dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_first = true; + return false; + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to build first_trip_SENode"); + } + + // We get the LastTripValue as GetFinalTripInductionNode(coefficient) + + // GetConstantTerm(source) + SENode* final_trip_SENode = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + GetFinalTripInductionNode(subscript_loop, coefficient), + GetConstantTerm(subscript_loop, source))); + + // If destination == LastTripValue, peel_last. + if (final_trip_SENode) { + PrintDebug("WeakZeroDestinationSIVTest built final_trip_SENode."); + if (final_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroDestinationSIVTest has found final_trip_SENode as an " + "SEConstantNode with value: " + + ToString(final_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (destination == final_trip_SENode) { + // We have found that peeling the last iteration will break dependency. + PrintDebug( + "WeakZeroDestinationSIVTest has found peeling final iteration will " + "break dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_last = true; + return false; + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to build final_trip_SENode"); + } + + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "WeakZeroDestinationSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::WeakCrossingSIVTest( + SENode* source, SENode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing WeakCrossingSIVTest."); + // We currently can't handle symbolic WeakCrossingSIVTests. If either source + // or destination are not SERecurrentNodes we must exit. + if (!source->AsSERecurrentNode() || !destination->AsSERecurrentNode()) { + PrintDebug( + "WeakCrossingSIVTest found source or destination != SERecurrentNode. " + "Exiting"); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; + } + + // Build an SENode for distance. + SENode* offset_delta = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateSubtraction( + destination->AsSERecurrentNode()->GetOffset(), + source->AsSERecurrentNode()->GetOffset())); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + int64_t distance = 0; + SEConstantNode* delta_constant = offset_delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + PrintDebug( + "WeakCrossingSIVTest folding offset_delta and coefficient to " + "constants."); + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + // Check if the distance is not integral or if it has a non-integral part + // equal to 1/2. + if (delta_value % (2 * coefficient_value) != 0 && + static_cast(delta_value % (2 * coefficient_value)) / + static_cast(2 * coefficient_value) != + 0.5) { + PrintDebug( + "WeakCrossingSIVTest proved independence through distance escaping " + "the loop bounds."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / (2 * coefficient_value); + } + + if (distance == 0) { + PrintDebug("WeakCrossingSIVTest found EQ dependence."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::EQ; + distance_entry->distance = 0; + return false; + } + } else { + PrintDebug( + "WeakCrossingSIVTest was unable to fold offset_delta and coefficient " + "to constants."); + } + + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "WeakCrossingSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +// Perform the GCD test if both, the source and the destination nodes, are in +// the form a0*i0 + a1*i1 + ... an*in + c. +bool LoopDependenceAnalysis::GCDMIVTest( + const std::pair& subscript_pair) { + auto source = std::get<0>(subscript_pair); + auto destination = std::get<1>(subscript_pair); + + // Bail out if source/destination is in an unexpected form. + if (!IsInCorrectFormForGCDTest(source) || + !IsInCorrectFormForGCDTest(destination)) { + return false; + } + + auto source_recurrences = GetAllTopLevelRecurrences(source); + auto dest_recurrences = GetAllTopLevelRecurrences(destination); + + // Bail out if all offsets and coefficients aren't constant. + if (!AreOffsetsAndCoefficientsConstant(source_recurrences) || + !AreOffsetsAndCoefficientsConstant(dest_recurrences)) { + return false; + } + + // Calculate the GCD of all coefficients. + auto source_constants = GetAllTopLevelConstants(source); + int64_t source_constant = + CalculateConstantTerm(source_recurrences, source_constants); + + auto dest_constants = GetAllTopLevelConstants(destination); + int64_t destination_constant = + CalculateConstantTerm(dest_recurrences, dest_constants); + + int64_t delta = std::abs(source_constant - destination_constant); + + int64_t running_gcd = 0; + + running_gcd = CalculateGCDFromCoefficients(source_recurrences, running_gcd); + running_gcd = CalculateGCDFromCoefficients(dest_recurrences, running_gcd); + + return delta % running_gcd != 0; +} + +using PartitionedSubscripts = + std::vector>>; +PartitionedSubscripts LoopDependenceAnalysis::PartitionSubscripts( + const std::vector& source_subscripts, + const std::vector& destination_subscripts) { + PartitionedSubscripts partitions{}; + + auto num_subscripts = source_subscripts.size(); + + // Create initial partitions with one subscript pair per partition. + for (size_t i = 0; i < num_subscripts; ++i) { + partitions.push_back({{source_subscripts[i], destination_subscripts[i]}}); + } + + // Iterate over the loops to create all partitions + for (auto loop : loops_) { + int64_t k = -1; + + for (size_t j = 0; j < partitions.size(); ++j) { + auto& current_partition = partitions[j]; + + // Does |loop| appear in |current_partition| + auto it = std::find_if( + current_partition.begin(), current_partition.end(), + [loop, + this](const std::pair& elem) -> bool { + auto source_recurrences = + scalar_evolution_.AnalyzeInstruction(std::get<0>(elem)) + ->CollectRecurrentNodes(); + auto destination_recurrences = + scalar_evolution_.AnalyzeInstruction(std::get<1>(elem)) + ->CollectRecurrentNodes(); + + source_recurrences.insert(source_recurrences.end(), + destination_recurrences.begin(), + destination_recurrences.end()); + + auto loops_in_pair = CollectLoops(source_recurrences); + auto end_it = loops_in_pair.end(); + + return std::find(loops_in_pair.begin(), end_it, loop) != end_it; + }); + + auto has_loop = it != current_partition.end(); + + if (has_loop) { + if (k == -1) { + k = j; + } else { + // Add |partitions[j]| to |partitions[k]| and discard |partitions[j]| + partitions[static_cast(k)].insert(current_partition.begin(), + current_partition.end()); + current_partition.clear(); + } + } + } + } + + // Remove empty (discarded) partitions + partitions.erase( + std::remove_if( + partitions.begin(), partitions.end(), + [](const std::set>& partition) { + return partition.empty(); + }), + partitions.end()); + + return partitions; +} + +Constraint* LoopDependenceAnalysis::IntersectConstraints( + Constraint* constraint_0, Constraint* constraint_1, + const SENode* lower_bound, const SENode* upper_bound) { + if (constraint_0->AsDependenceNone()) { + return constraint_1; + } else if (constraint_1->AsDependenceNone()) { + return constraint_0; + } + + // Both constraints are distances. Either the same distance or independent. + if (constraint_0->AsDependenceDistance() && + constraint_1->AsDependenceDistance()) { + auto dist_0 = constraint_0->AsDependenceDistance(); + auto dist_1 = constraint_1->AsDependenceDistance(); + + if (*dist_0->GetDistance() == *dist_1->GetDistance()) { + return constraint_0; + } else { + return make_constraint(); + } + } + + // Both constraints are points. Either the same point or independent. + if (constraint_0->AsDependencePoint() && constraint_1->AsDependencePoint()) { + auto point_0 = constraint_0->AsDependencePoint(); + auto point_1 = constraint_1->AsDependencePoint(); + + if (*point_0->GetSource() == *point_1->GetSource() && + *point_0->GetDestination() == *point_1->GetDestination()) { + return constraint_0; + } else { + return make_constraint(); + } + } + + // Both constraints are lines/distances. + if ((constraint_0->AsDependenceDistance() || + constraint_0->AsDependenceLine()) && + (constraint_1->AsDependenceDistance() || + constraint_1->AsDependenceLine())) { + auto is_distance_0 = constraint_0->AsDependenceDistance() != nullptr; + auto is_distance_1 = constraint_1->AsDependenceDistance() != nullptr; + + auto a0 = is_distance_0 ? scalar_evolution_.CreateConstant(1) + : constraint_0->AsDependenceLine()->GetA(); + auto b0 = is_distance_0 ? scalar_evolution_.CreateConstant(-1) + : constraint_0->AsDependenceLine()->GetB(); + auto c0 = + is_distance_0 + ? scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateNegation( + constraint_0->AsDependenceDistance()->GetDistance())) + : constraint_0->AsDependenceLine()->GetC(); + + auto a1 = is_distance_1 ? scalar_evolution_.CreateConstant(1) + : constraint_1->AsDependenceLine()->GetA(); + auto b1 = is_distance_1 ? scalar_evolution_.CreateConstant(-1) + : constraint_1->AsDependenceLine()->GetB(); + auto c1 = + is_distance_1 + ? scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateNegation( + constraint_1->AsDependenceDistance()->GetDistance())) + : constraint_1->AsDependenceLine()->GetC(); + + if (a0->AsSEConstantNode() && b0->AsSEConstantNode() && + c0->AsSEConstantNode() && a1->AsSEConstantNode() && + b1->AsSEConstantNode() && c1->AsSEConstantNode()) { + auto constant_a0 = a0->AsSEConstantNode()->FoldToSingleValue(); + auto constant_b0 = b0->AsSEConstantNode()->FoldToSingleValue(); + auto constant_c0 = c0->AsSEConstantNode()->FoldToSingleValue(); + + auto constant_a1 = a1->AsSEConstantNode()->FoldToSingleValue(); + auto constant_b1 = b1->AsSEConstantNode()->FoldToSingleValue(); + auto constant_c1 = c1->AsSEConstantNode()->FoldToSingleValue(); + + // a & b can't both be zero, otherwise it wouldn't be line. + if (NormalizeAndCompareFractions(constant_a0, constant_b0, constant_a1, + constant_b1)) { + // Slopes are equal, either parallel lines or the same line. + + if (constant_b0 == 0 && constant_b1 == 0) { + if (NormalizeAndCompareFractions(constant_c0, constant_a0, + constant_c1, constant_a1)) { + return constraint_0; + } + + return make_constraint(); + } else if (NormalizeAndCompareFractions(constant_c0, constant_b0, + constant_c1, constant_b1)) { + // Same line. + return constraint_0; + } else { + // Parallel lines can't intersect, report independence. + return make_constraint(); + } + + } else { + // Lines are not parallel, therefore, they must intersect. + + // Calculate intersection. + if (upper_bound->AsSEConstantNode() && + lower_bound->AsSEConstantNode()) { + auto constant_lower_bound = + lower_bound->AsSEConstantNode()->FoldToSingleValue(); + auto constant_upper_bound = + upper_bound->AsSEConstantNode()->FoldToSingleValue(); + + auto up = constant_b1 * constant_c0 - constant_b0 * constant_c1; + // Both b or both a can't be 0, so down is never 0 + // otherwise would have entered the parallel line section. + auto down = constant_b1 * constant_a0 - constant_b0 * constant_a1; + + auto x_coord = up / down; + + int64_t y_coord = 0; + int64_t arg1 = 0; + int64_t const_b_to_use = 0; + + if (constant_b1 != 0) { + arg1 = constant_c1 - constant_a1 * x_coord; + y_coord = arg1 / constant_b1; + const_b_to_use = constant_b1; + } else if (constant_b0 != 0) { + arg1 = constant_c0 - constant_a0 * x_coord; + y_coord = arg1 / constant_b0; + const_b_to_use = constant_b0; + } + + if (up % down == 0 && + arg1 % const_b_to_use == 0 && // Coordinates are integers. + constant_lower_bound <= + x_coord && // x_coord is within loop bounds. + x_coord <= constant_upper_bound && + constant_lower_bound <= + y_coord && // y_coord is within loop bounds. + y_coord <= constant_upper_bound) { + // Lines intersect at integer coordinates. + return make_constraint( + scalar_evolution_.CreateConstant(x_coord), + scalar_evolution_.CreateConstant(y_coord), + constraint_0->GetLoop()); + + } else { + return make_constraint(); + } + + } else { + // Not constants, bail out. + return make_constraint(); + } + } + + } else { + // Not constants, bail out. + return make_constraint(); + } + } + + // One constraint is a line/distance and the other is a point. + if ((constraint_0->AsDependencePoint() && + (constraint_1->AsDependenceLine() || + constraint_1->AsDependenceDistance())) || + (constraint_1->AsDependencePoint() && + (constraint_0->AsDependenceLine() || + constraint_0->AsDependenceDistance()))) { + auto point_0 = constraint_0->AsDependencePoint() != nullptr; + + auto point = point_0 ? constraint_0->AsDependencePoint() + : constraint_1->AsDependencePoint(); + + auto line_or_distance = point_0 ? constraint_1 : constraint_0; + + auto is_distance = line_or_distance->AsDependenceDistance() != nullptr; + + auto a = is_distance ? scalar_evolution_.CreateConstant(1) + : line_or_distance->AsDependenceLine()->GetA(); + auto b = is_distance ? scalar_evolution_.CreateConstant(-1) + : line_or_distance->AsDependenceLine()->GetB(); + auto c = + is_distance + ? scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateNegation( + line_or_distance->AsDependenceDistance()->GetDistance())) + : line_or_distance->AsDependenceLine()->GetC(); + + auto x = point->GetSource(); + auto y = point->GetDestination(); + + if (a->AsSEConstantNode() && b->AsSEConstantNode() && + c->AsSEConstantNode() && x->AsSEConstantNode() && + y->AsSEConstantNode()) { + auto constant_a = a->AsSEConstantNode()->FoldToSingleValue(); + auto constant_b = b->AsSEConstantNode()->FoldToSingleValue(); + auto constant_c = c->AsSEConstantNode()->FoldToSingleValue(); + + auto constant_x = x->AsSEConstantNode()->FoldToSingleValue(); + auto constant_y = y->AsSEConstantNode()->FoldToSingleValue(); + + auto left_hand_side = constant_a * constant_x + constant_b * constant_y; + + if (left_hand_side == constant_c) { + // Point is on line, return point + return point_0 ? constraint_0 : constraint_1; + } else { + // Point not on line, report independence (empty constraint). + return make_constraint(); + } + + } else { + // Not constants, bail out. + return make_constraint(); + } + } + + return nullptr; +} + +// Propagate constraints function as described in section 5 of Practical +// Dependence Testing, Goff, Kennedy, Tseng, 1991. +SubscriptPair LoopDependenceAnalysis::PropagateConstraints( + const SubscriptPair& subscript_pair, + const std::vector& constraints) { + SENode* new_first = subscript_pair.first; + SENode* new_second = subscript_pair.second; + + for (auto& constraint : constraints) { + // In the paper this is a[k]. We're extracting the coefficient ('a') of a + // recurrent expression with respect to the loop 'k'. + SENode* coefficient_of_recurrent = + scalar_evolution_.GetCoefficientFromRecurrentTerm( + new_first, constraint->GetLoop()); + + // In the paper this is a'[k]. + SENode* coefficient_of_recurrent_prime = + scalar_evolution_.GetCoefficientFromRecurrentTerm( + new_second, constraint->GetLoop()); + + if (constraint->GetType() == Constraint::Distance) { + DependenceDistance* as_distance = constraint->AsDependenceDistance(); + + // In the paper this is a[k]*d + SENode* rhs = scalar_evolution_.CreateMultiplyNode( + coefficient_of_recurrent, as_distance->GetDistance()); + + // In the paper this is a[k] <- 0 + SENode* zeroed_coefficient = + scalar_evolution_.BuildGraphWithoutRecurrentTerm( + new_first, constraint->GetLoop()); + + // In the paper this is e <- e - a[k]*d. + new_first = scalar_evolution_.CreateSubtraction(zeroed_coefficient, rhs); + new_first = scalar_evolution_.SimplifyExpression(new_first); + + // In the paper this is a'[k] - a[k]. + SENode* new_child = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(coefficient_of_recurrent_prime, + coefficient_of_recurrent)); + + // In the paper this is a'[k]'i[k]. + SERecurrentNode* prime_recurrent = + scalar_evolution_.GetRecurrentTerm(new_second, constraint->GetLoop()); + + if (!prime_recurrent) continue; + + // As we hash the nodes we need to create a new node when we update a + // child. + SENode* new_recurrent = scalar_evolution_.CreateRecurrentExpression( + constraint->GetLoop(), prime_recurrent->GetOffset(), new_child); + // In the paper this is a'[k] <- a'[k] - a[k]. + new_second = scalar_evolution_.UpdateChildNode( + new_second, prime_recurrent, new_recurrent); + } + } + + new_second = scalar_evolution_.SimplifyExpression(new_second); + return std::make_pair(new_first, new_second); +} + +bool LoopDependenceAnalysis::DeltaTest( + const std::vector& coupled_subscripts, + DistanceVector* dv_entry) { + std::vector constraints(loops_.size()); + + std::vector loop_appeared(loops_.size()); + + std::generate(std::begin(constraints), std::end(constraints), + [this]() { return make_constraint(); }); + + // Separate SIV and MIV subscripts + std::vector siv_subscripts{}; + std::vector miv_subscripts{}; + + for (const auto& subscript_pair : coupled_subscripts) { + if (IsSIV(subscript_pair)) { + siv_subscripts.push_back(subscript_pair); + } else { + miv_subscripts.push_back(subscript_pair); + } + } + + // Delta Test + while (!siv_subscripts.empty()) { + std::vector results(siv_subscripts.size()); + + std::vector current_distances( + siv_subscripts.size(), DistanceVector(loops_.size())); + + // Apply SIV test to all SIV subscripts, report independence if any of them + // is independent + std::transform( + std::begin(siv_subscripts), std::end(siv_subscripts), + std::begin(current_distances), std::begin(results), + [this](SubscriptPair& p, DistanceVector& d) { return SIVTest(p, &d); }); + + if (std::accumulate(std::begin(results), std::end(results), false, + std::logical_or{})) { + return true; + } + + // Derive new constraint vector. + std::vector> all_new_constrants{}; + + for (size_t i = 0; i < siv_subscripts.size(); ++i) { + auto loop = GetLoopForSubscriptPair(siv_subscripts[i]); + + auto loop_id = + std::distance(std::begin(loops_), + std::find(std::begin(loops_), std::end(loops_), loop)); + + loop_appeared[loop_id] = true; + auto distance_entry = current_distances[i].GetEntries()[loop_id]; + + if (distance_entry.dependence_information == + DistanceEntry::DependenceInformation::DISTANCE) { + // Construct a DependenceDistance. + auto node = scalar_evolution_.CreateConstant(distance_entry.distance); + + all_new_constrants.push_back( + {make_constraint(node, loop), loop_id}); + } else { + // Construct a DependenceLine. + const auto& subscript_pair = siv_subscripts[i]; + SENode* source_node = std::get<0>(subscript_pair); + SENode* destination_node = std::get<1>(subscript_pair); + + int64_t source_induction_count = CountInductionVariables(source_node); + int64_t destination_induction_count = + CountInductionVariables(destination_node); + + SENode* a = nullptr; + SENode* b = nullptr; + SENode* c = nullptr; + + if (destination_induction_count != 0) { + a = destination_node->AsSERecurrentNode()->GetCoefficient(); + c = scalar_evolution_.CreateNegation( + destination_node->AsSERecurrentNode()->GetOffset()); + } else { + a = scalar_evolution_.CreateConstant(0); + c = scalar_evolution_.CreateNegation(destination_node); + } + + if (source_induction_count != 0) { + b = scalar_evolution_.CreateNegation( + source_node->AsSERecurrentNode()->GetCoefficient()); + c = scalar_evolution_.CreateAddNode( + c, source_node->AsSERecurrentNode()->GetOffset()); + } else { + b = scalar_evolution_.CreateConstant(0); + c = scalar_evolution_.CreateAddNode(c, source_node); + } + + a = scalar_evolution_.SimplifyExpression(a); + b = scalar_evolution_.SimplifyExpression(b); + c = scalar_evolution_.SimplifyExpression(c); + + all_new_constrants.push_back( + {make_constraint(a, b, c, loop), loop_id}); + } + } + + // Calculate the intersection between the new and existing constraints. + std::vector intersection = constraints; + for (const auto& constraint_to_intersect : all_new_constrants) { + auto loop_id = std::get<1>(constraint_to_intersect); + auto loop = loops_[loop_id]; + intersection[loop_id] = IntersectConstraints( + intersection[loop_id], std::get<0>(constraint_to_intersect), + GetLowerBound(loop), GetUpperBound(loop)); + } + + // Report independence if an empty constraint (DependenceEmpty) is found. + auto first_empty = + std::find_if(std::begin(intersection), std::end(intersection), + [](Constraint* constraint) { + return constraint->AsDependenceEmpty() != nullptr; + }); + if (first_empty != std::end(intersection)) { + return true; + } + std::vector new_siv_subscripts{}; + std::vector new_miv_subscripts{}; + + auto equal = + std::equal(std::begin(constraints), std::end(constraints), + std::begin(intersection), + [](Constraint* a, Constraint* b) { return *a == *b; }); + + // If any constraints have changed, propagate them into the rest of the + // subscripts possibly creating new ZIV/SIV subscripts. + if (!equal) { + std::vector new_subscripts(miv_subscripts.size()); + + // Propagate constraints into MIV subscripts + std::transform(std::begin(miv_subscripts), std::end(miv_subscripts), + std::begin(new_subscripts), + [this, &intersection](SubscriptPair& subscript_pair) { + return PropagateConstraints(subscript_pair, + intersection); + }); + + // If a ZIV subscript is returned, apply test, otherwise, update untested + // subscripts. + for (auto& subscript : new_subscripts) { + if (IsZIV(subscript) && ZIVTest(subscript)) { + return true; + } else if (IsSIV(subscript)) { + new_siv_subscripts.push_back(subscript); + } else { + new_miv_subscripts.push_back(subscript); + } + } + } + + // Set new constraints and subscripts to test. + std::swap(siv_subscripts, new_siv_subscripts); + std::swap(miv_subscripts, new_miv_subscripts); + std::swap(constraints, intersection); + } + + // Create the dependence vector from the constraints. + for (size_t i = 0; i < loops_.size(); ++i) { + // Don't touch entries for loops that weren't tested. + if (loop_appeared[i]) { + auto current_constraint = constraints[i]; + auto& current_distance_entry = (*dv_entry).GetEntries()[i]; + + if (auto dependence_distance = + current_constraint->AsDependenceDistance()) { + if (auto constant_node = + dependence_distance->GetDistance()->AsSEConstantNode()) { + current_distance_entry.dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + + current_distance_entry.distance = constant_node->FoldToSingleValue(); + if (current_distance_entry.distance == 0) { + current_distance_entry.direction = DistanceEntry::Directions::EQ; + } else if (current_distance_entry.distance < 0) { + current_distance_entry.direction = DistanceEntry::Directions::GT; + } else { + current_distance_entry.direction = DistanceEntry::Directions::LT; + } + } + } else if (auto dependence_point = + current_constraint->AsDependencePoint()) { + auto source = dependence_point->GetSource(); + auto destination = dependence_point->GetDestination(); + + if (source->AsSEConstantNode() && destination->AsSEConstantNode()) { + current_distance_entry = DistanceEntry( + source->AsSEConstantNode()->FoldToSingleValue(), + destination->AsSEConstantNode()->FoldToSingleValue()); + } + } + } + } + + // Test any remaining MIV subscripts and report independence if found. + std::vector results(miv_subscripts.size()); + + std::transform(std::begin(miv_subscripts), std::end(miv_subscripts), + std::begin(results), + [this](const SubscriptPair& p) { return GCDMIVTest(p); }); + + return std::accumulate(std::begin(results), std::end(results), false, + std::logical_or{}); +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_dependence.h b/3rdparty/spirv-tools/source/opt/loop_dependence.h new file mode 100644 index 000000000..582c8d0ac --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_dependence.h @@ -0,0 +1,558 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_DEPENDENCE_H_ +#define SOURCE_OPT_LOOP_DEPENDENCE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/scalar_analysis.h" + +namespace spvtools { +namespace opt { + +// Stores information about dependence between a load and a store wrt a single +// loop in a loop nest. +// DependenceInformation +// * UNKNOWN if no dependence information can be gathered or is gathered +// for it. +// * DIRECTION if a dependence direction could be found, but not a +// distance. +// * DISTANCE if a dependence distance could be found. +// * PEEL if peeling either the first or last iteration will break +// dependence between the given load and store. +// * IRRELEVANT if it has no effect on the dependence between the given +// load and store. +// +// If peel_first == true, the analysis has found that peeling the first +// iteration of this loop will break dependence. +// +// If peel_last == true, the analysis has found that peeling the last iteration +// of this loop will break dependence. +class DistanceEntry { + public: + enum DependenceInformation { + UNKNOWN = 0, + DIRECTION = 1, + DISTANCE = 2, + PEEL = 3, + IRRELEVANT = 4, + POINT = 5 + }; + enum Directions { + NONE = 0, + LT = 1, + EQ = 2, + LE = 3, + GT = 4, + NE = 5, + GE = 6, + ALL = 7 + }; + DependenceInformation dependence_information; + Directions direction; + int64_t distance; + bool peel_first; + bool peel_last; + int64_t point_x; + int64_t point_y; + + DistanceEntry() + : dependence_information(DependenceInformation::UNKNOWN), + direction(Directions::ALL), + distance(0), + peel_first(false), + peel_last(false), + point_x(0), + point_y(0) {} + + explicit DistanceEntry(Directions direction_) + : dependence_information(DependenceInformation::DIRECTION), + direction(direction_), + distance(0), + peel_first(false), + peel_last(false), + point_x(0), + point_y(0) {} + + DistanceEntry(Directions direction_, int64_t distance_) + : dependence_information(DependenceInformation::DISTANCE), + direction(direction_), + distance(distance_), + peel_first(false), + peel_last(false), + point_x(0), + point_y(0) {} + + DistanceEntry(int64_t x, int64_t y) + : dependence_information(DependenceInformation::POINT), + direction(Directions::ALL), + distance(0), + peel_first(false), + peel_last(false), + point_x(x), + point_y(y) {} + + bool operator==(const DistanceEntry& rhs) const { + return direction == rhs.direction && peel_first == rhs.peel_first && + peel_last == rhs.peel_last && distance == rhs.distance && + point_x == rhs.point_x && point_y == rhs.point_y; + } + + bool operator!=(const DistanceEntry& rhs) const { return !(*this == rhs); } +}; + +// Stores a vector of DistanceEntrys, one per loop in the analysis. +// A DistanceVector holds all of the information gathered in a dependence +// analysis wrt the loops stored in the LoopDependenceAnalysis performing the +// analysis. +class DistanceVector { + public: + explicit DistanceVector(size_t size) : entries(size, DistanceEntry{}) {} + + explicit DistanceVector(std::vector entries_) + : entries(entries_) {} + + DistanceEntry& GetEntry(size_t index) { return entries[index]; } + const DistanceEntry& GetEntry(size_t index) const { return entries[index]; } + + std::vector& GetEntries() { return entries; } + const std::vector& GetEntries() const { return entries; } + + bool operator==(const DistanceVector& rhs) const { + if (entries.size() != rhs.entries.size()) { + return false; + } + for (size_t i = 0; i < entries.size(); ++i) { + if (entries[i] != rhs.entries[i]) { + return false; + } + } + return true; + } + bool operator!=(const DistanceVector& rhs) const { return !(*this == rhs); } + + private: + std::vector entries; +}; + +class DependenceLine; +class DependenceDistance; +class DependencePoint; +class DependenceNone; +class DependenceEmpty; + +class Constraint { + public: + explicit Constraint(const Loop* loop) : loop_(loop) {} + enum ConstraintType { Line, Distance, Point, None, Empty }; + + virtual ConstraintType GetType() const = 0; + + virtual ~Constraint() {} + + // Get the loop this constraint belongs to. + const Loop* GetLoop() const { return loop_; } + + bool operator==(const Constraint& other) const; + + bool operator!=(const Constraint& other) const; + +#define DeclareCastMethod(target) \ + virtual target* As##target() { return nullptr; } \ + virtual const target* As##target() const { return nullptr; } + DeclareCastMethod(DependenceLine); + DeclareCastMethod(DependenceDistance); + DeclareCastMethod(DependencePoint); + DeclareCastMethod(DependenceNone); + DeclareCastMethod(DependenceEmpty); +#undef DeclareCastMethod + + protected: + const Loop* loop_; +}; + +class DependenceLine : public Constraint { + public: + DependenceLine(SENode* a, SENode* b, SENode* c, const Loop* loop) + : Constraint(loop), a_(a), b_(b), c_(c) {} + + ConstraintType GetType() const final { return Line; } + + DependenceLine* AsDependenceLine() final { return this; } + const DependenceLine* AsDependenceLine() const final { return this; } + + SENode* GetA() const { return a_; } + SENode* GetB() const { return b_; } + SENode* GetC() const { return c_; } + + private: + SENode* a_; + SENode* b_; + SENode* c_; +}; + +class DependenceDistance : public Constraint { + public: + DependenceDistance(SENode* distance, const Loop* loop) + : Constraint(loop), distance_(distance) {} + + ConstraintType GetType() const final { return Distance; } + + DependenceDistance* AsDependenceDistance() final { return this; } + const DependenceDistance* AsDependenceDistance() const final { return this; } + + SENode* GetDistance() const { return distance_; } + + private: + SENode* distance_; +}; + +class DependencePoint : public Constraint { + public: + DependencePoint(SENode* source, SENode* destination, const Loop* loop) + : Constraint(loop), source_(source), destination_(destination) {} + + ConstraintType GetType() const final { return Point; } + + DependencePoint* AsDependencePoint() final { return this; } + const DependencePoint* AsDependencePoint() const final { return this; } + + SENode* GetSource() const { return source_; } + SENode* GetDestination() const { return destination_; } + + private: + SENode* source_; + SENode* destination_; +}; + +class DependenceNone : public Constraint { + public: + DependenceNone() : Constraint(nullptr) {} + ConstraintType GetType() const final { return None; } + + DependenceNone* AsDependenceNone() final { return this; } + const DependenceNone* AsDependenceNone() const final { return this; } +}; + +class DependenceEmpty : public Constraint { + public: + DependenceEmpty() : Constraint(nullptr) {} + ConstraintType GetType() const final { return Empty; } + + DependenceEmpty* AsDependenceEmpty() final { return this; } + const DependenceEmpty* AsDependenceEmpty() const final { return this; } +}; + +// Provides dependence information between a store instruction and a load +// instruction inside the same loop in a loop nest. +// +// The analysis can only check dependence between stores and loads with regard +// to the loop nest it is created with. +// +// The analysis can output debugging information to a stream. The output +// describes the control flow of the analysis and what information it can deduce +// at each step. +// SetDebugStream and ClearDebugStream are provided for this functionality. +// +// The dependency algorithm is based on the 1990 Paper +// Practical Dependence Testing +// Gina Goff, Ken Kennedy, Chau-Wen Tseng +// +// The algorithm first identifies subscript pairs between the load and store. +// Each pair is tested until all have been tested or independence is found. +// The number of induction variables in a pair determines which test to perform +// on it; +// Zero Index Variable (ZIV) is used when no induction variables are present +// in the pair. +// Single Index Variable (SIV) is used when only one induction variable is +// present, but may occur multiple times in the pair. +// Multiple Index Variable (MIV) is used when more than one induction variable +// is present in the pair. +class LoopDependenceAnalysis { + public: + LoopDependenceAnalysis(IRContext* context, std::vector loops) + : context_(context), + loops_(loops), + scalar_evolution_(context), + debug_stream_(nullptr), + constraints_{} {} + + // Finds the dependence between |source| and |destination|. + // |source| should be an OpLoad. + // |destination| should be an OpStore. + // Any direction and distance information found will be stored in + // |distance_vector|. + // Returns true if independence is found, false otherwise. + bool GetDependence(const Instruction* source, const Instruction* destination, + DistanceVector* distance_vector); + + // Returns true if |subscript_pair| represents a Zero Index Variable pair + // (ZIV) + bool IsZIV(const std::pair& subscript_pair); + + // Returns true if |subscript_pair| represents a Single Index Variable + // (SIV) pair + bool IsSIV(const std::pair& subscript_pair); + + // Returns true if |subscript_pair| represents a Multiple Index Variable + // (MIV) pair + bool IsMIV(const std::pair& subscript_pair); + + // Finds the lower bound of |loop| as an SENode* and returns the result. + // The lower bound is the starting value of the loops induction variable + SENode* GetLowerBound(const Loop* loop); + + // Finds the upper bound of |loop| as an SENode* and returns the result. + // The upper bound is the last value before the loop exit condition is met. + SENode* GetUpperBound(const Loop* loop); + + // Returns true if |value| is between |bound_one| and |bound_two| (inclusive). + bool IsWithinBounds(int64_t value, int64_t bound_one, int64_t bound_two); + + // Finds the bounds of |loop| as upper_bound - lower_bound and returns the + // resulting SENode. + // If the operations can not be completed a nullptr is returned. + SENode* GetTripCount(const Loop* loop); + + // Returns the SENode* produced by building an SENode from the result of + // calling GetInductionInitValue on |loop|. + // If the operation can not be completed a nullptr is returned. + SENode* GetFirstTripInductionNode(const Loop* loop); + + // Returns the SENode* produced by building an SENode from the result of + // GetFirstTripInductionNode + (GetTripCount - 1) * induction_coefficient. + // If the operation can not be completed a nullptr is returned. + SENode* GetFinalTripInductionNode(const Loop* loop, + SENode* induction_coefficient); + + // Returns all the distinct loops that appear in |nodes|. + std::set CollectLoops( + const std::vector& nodes); + + // Returns all the distinct loops that appear in |source| and |destination|. + std::set CollectLoops(SENode* source, SENode* destination); + + // Returns true if |distance| is provably outside the loop bounds. + // |coefficient| must be an SENode representing the coefficient of the + // induction variable of |loop|. + // This method is able to handle some symbolic cases which IsWithinBounds + // can't handle. + bool IsProvablyOutsideOfLoopBounds(const Loop* loop, SENode* distance, + SENode* coefficient); + + // Sets the ostream for debug information for the analysis. + void SetDebugStream(std::ostream& debug_stream) { + debug_stream_ = &debug_stream; + } + + // Clears the stored ostream to stop debug information printing. + void ClearDebugStream() { debug_stream_ = nullptr; } + + // Returns the ScalarEvolutionAnalysis used by this analysis. + ScalarEvolutionAnalysis* GetScalarEvolution() { return &scalar_evolution_; } + + // Creates a new constraint of type |T| and returns the pointer to it. + template + Constraint* make_constraint(Args&&... args) { + constraints_.push_back( + std::unique_ptr(new T(std::forward(args)...))); + + return constraints_.back().get(); + } + + // Subscript partitioning as described in Figure 1 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + // Partitions the subscripts into independent subscripts and minimally coupled + // sets of subscripts. + // Returns the partitioning of subscript pairs. Sets of size 1 indicates an + // independent subscript-pair and others indicate coupled sets. + using PartitionedSubscripts = + std::vector>>; + PartitionedSubscripts PartitionSubscripts( + const std::vector& source_subscripts, + const std::vector& destination_subscripts); + + // Returns the Loop* matching the loop for |subscript_pair|. + // |subscript_pair| must be an SIV pair. + const Loop* GetLoopForSubscriptPair( + const std::pair& subscript_pair); + + // Returns the DistanceEntry matching the loop for |subscript_pair|. + // |subscript_pair| must be an SIV pair. + DistanceEntry* GetDistanceEntryForSubscriptPair( + const std::pair& subscript_pair, + DistanceVector* distance_vector); + + // Returns the DistanceEntry matching |loop|. + DistanceEntry* GetDistanceEntryForLoop(const Loop* loop, + DistanceVector* distance_vector); + + // Returns a vector of Instruction* which form the subscripts of the array + // access defined by the access chain |instruction|. + std::vector GetSubscripts(const Instruction* instruction); + + // Delta test as described in Figure 3 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + bool DeltaTest( + const std::vector>& coupled_subscripts, + DistanceVector* dv_entry); + + // Constraint propagation as described in Figure 5 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + std::pair PropagateConstraints( + const std::pair& subscript_pair, + const std::vector& constraints); + + // Constraint intersection as described in Figure 4 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + Constraint* IntersectConstraints(Constraint* constraint_0, + Constraint* constraint_1, + const SENode* lower_bound, + const SENode* upper_bound); + + // Returns true if each loop in |loops| is in a form supported by this + // analysis. + // A loop is supported if it has a single induction variable and that + // induction variable has a step of +1 or -1 per loop iteration. + bool CheckSupportedLoops(std::vector loops); + + // Returns true if |loop| is in a form supported by this analysis. + // A loop is supported if it has a single induction variable and that + // induction variable has a step of +1 or -1 per loop iteration. + bool IsSupportedLoop(const Loop* loop); + + private: + IRContext* context_; + + // The loop nest we are analysing the dependence of. + std::vector loops_; + + // The ScalarEvolutionAnalysis used by this analysis to store and perform much + // of its logic. + ScalarEvolutionAnalysis scalar_evolution_; + + // The ostream debug information for the analysis to print to. + std::ostream* debug_stream_; + + // Stores all the constraints created by the analysis. + std::list> constraints_; + + // Returns true if independence can be proven and false if it can't be proven. + bool ZIVTest(const std::pair& subscript_pair); + + // Analyzes the subscript pair to find an applicable SIV test. + // Returns true if independence can be proven and false if it can't be proven. + bool SIVTest(const std::pair& subscript_pair, + DistanceVector* distance_vector); + + // Takes the form a*i + c1, a*i + c2 + // When c1 and c2 are loop invariant and a is constant + // distance = (c1 - c2)/a + // < if distance > 0 + // direction = = if distance = 0 + // > if distance < 0 + // Returns true if independence is proven and false if it can't be proven. + bool StrongSIVTest(SENode* source, SENode* destination, SENode* coeff, + DistanceEntry* distance_entry); + + // Takes for form a*i + c1, a*i + c2 + // where c1 and c2 are loop invariant and a is constant. + // c1 and/or c2 contain one or more SEValueUnknown nodes. + bool SymbolicStrongSIVTest(SENode* source, SENode* destination, + SENode* coefficient, + DistanceEntry* distance_entry); + + // Takes the form a1*i + c1, a2*i + c2 + // where a1 = 0 + // distance = (c1 - c2) / a2 + // Returns true if independence is proven and false if it can't be proven. + bool WeakZeroSourceSIVTest(SENode* source, SERecurrentNode* destination, + SENode* coefficient, + DistanceEntry* distance_entry); + + // Takes the form a1*i + c1, a2*i + c2 + // where a2 = 0 + // distance = (c2 - c1) / a1 + // Returns true if independence is proven and false if it can't be proven. + bool WeakZeroDestinationSIVTest(SERecurrentNode* source, SENode* destination, + SENode* coefficient, + DistanceEntry* distance_entry); + + // Takes the form a1*i + c1, a2*i + c2 + // where a1 = -a2 + // distance = (c2 - c1) / 2*a1 + // Returns true if independence is proven and false if it can't be proven. + bool WeakCrossingSIVTest(SENode* source, SENode* destination, + SENode* coefficient, DistanceEntry* distance_entry); + + // Uses the def_use_mgr to get the instruction referenced by + // SingleWordInOperand(|id|) when called on |instruction|. + Instruction* GetOperandDefinition(const Instruction* instruction, int id); + + // Perform the GCD test if both, the source and the destination nodes, are in + // the form a0*i0 + a1*i1 + ... an*in + c. + bool GCDMIVTest(const std::pair& subscript_pair); + + // Finds the number of induction variables in |node|. + // Returns -1 on failure. + int64_t CountInductionVariables(SENode* node); + + // Finds the number of induction variables shared between |source| and + // |destination|. + // Returns -1 on failure. + int64_t CountInductionVariables(SENode* source, SENode* destination); + + // Takes the offset from the induction variable and subtracts the lower bound + // from it to get the constant term added to the induction. + // Returns the resuting constant term, or nullptr if it could not be produced. + SENode* GetConstantTerm(const Loop* loop, SERecurrentNode* induction); + + // Marks all the distance entries in |distance_vector| that were relate to + // loops in |loops_| but were not used in any subscripts as irrelevant to the + // to the dependence test. + void MarkUnsusedDistanceEntriesAsIrrelevant(const Instruction* source, + const Instruction* destination, + DistanceVector* distance_vector); + + // Converts |value| to a std::string and returns the result. + // This is required because Android does not compile std::to_string. + template + std::string ToString(valueT value) { + std::ostringstream string_stream; + string_stream << value; + return string_stream.str(); + } + + // Prints |debug_msg| and "\n" to the ostream pointed to by |debug_stream_|. + // Won't print anything if |debug_stream_| is nullptr. + void PrintDebug(std::string debug_msg); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_DEPENDENCE_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_dependence_helpers.cpp b/3rdparty/spirv-tools/source/opt/loop_dependence_helpers.cpp new file mode 100644 index 000000000..de27a0a72 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_dependence_helpers.cpp @@ -0,0 +1,541 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_dependence.h" + +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/scalar_analysis_nodes.h" + +namespace spvtools { +namespace opt { + +bool LoopDependenceAnalysis::IsZIV( + const std::pair& subscript_pair) { + return CountInductionVariables(subscript_pair.first, subscript_pair.second) == + 0; +} + +bool LoopDependenceAnalysis::IsSIV( + const std::pair& subscript_pair) { + return CountInductionVariables(subscript_pair.first, subscript_pair.second) == + 1; +} + +bool LoopDependenceAnalysis::IsMIV( + const std::pair& subscript_pair) { + return CountInductionVariables(subscript_pair.first, subscript_pair.second) > + 1; +} + +SENode* LoopDependenceAnalysis::GetLowerBound(const Loop* loop) { + Instruction* cond_inst = loop->GetConditionInst(); + if (!cond_inst) { + return nullptr; + } + Instruction* lower_inst = GetOperandDefinition(cond_inst, 0); + switch (cond_inst->opcode()) { + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: { + // If we have a phi we are looking at the induction variable. We look + // through the phi to the initial value of the phi upon entering the loop. + if (lower_inst->opcode() == SpvOpPhi) { + lower_inst = GetOperandDefinition(lower_inst, 0); + // We don't handle looking through multiple phis. + if (lower_inst->opcode() == SpvOpPhi) { + return nullptr; + } + } + return scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(lower_inst)); + } + default: + return nullptr; + } +} + +SENode* LoopDependenceAnalysis::GetUpperBound(const Loop* loop) { + Instruction* cond_inst = loop->GetConditionInst(); + if (!cond_inst) { + return nullptr; + } + Instruction* upper_inst = GetOperandDefinition(cond_inst, 1); + switch (cond_inst->opcode()) { + case SpvOpULessThan: + case SpvOpSLessThan: { + // When we have a < condition we must subtract 1 from the analyzed upper + // instruction. + SENode* upper_bound = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction( + scalar_evolution_.AnalyzeInstruction(upper_inst), + scalar_evolution_.CreateConstant(1))); + return upper_bound; + } + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: { + // When we have a > condition we must add 1 to the analyzed upper + // instruction. + SENode* upper_bound = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + scalar_evolution_.AnalyzeInstruction(upper_inst), + scalar_evolution_.CreateConstant(1))); + return upper_bound; + } + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: { + // We don't need to modify the results of analyzing when we have <= or >=. + SENode* upper_bound = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(upper_inst)); + return upper_bound; + } + default: + return nullptr; + } +} + +bool LoopDependenceAnalysis::IsWithinBounds(int64_t value, int64_t bound_one, + int64_t bound_two) { + if (bound_one < bound_two) { + // If |bound_one| is the lower bound. + return (value >= bound_one && value <= bound_two); + } else if (bound_one > bound_two) { + // If |bound_two| is the lower bound. + return (value >= bound_two && value <= bound_one); + } else { + // Both bounds have the same value. + return value == bound_one; + } +} + +bool LoopDependenceAnalysis::IsProvablyOutsideOfLoopBounds( + const Loop* loop, SENode* distance, SENode* coefficient) { + // We test to see if we can reduce the coefficient to an integral constant. + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (!coefficient_constant) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds could not reduce coefficient to a " + "SEConstantNode so must exit."); + return false; + } + + SENode* lower_bound = GetLowerBound(loop); + SENode* upper_bound = GetUpperBound(loop); + if (!lower_bound || !upper_bound) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds could not get both the lower and upper " + "bounds so must exit."); + return false; + } + // If the coefficient is positive we calculate bounds as upper - lower + // If the coefficient is negative we calculate bounds as lower - upper + SENode* bounds = nullptr; + if (coefficient_constant->FoldToSingleValue() >= 0) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found coefficient >= 0.\n" + "Using bounds as upper - lower."); + bounds = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(upper_bound, lower_bound)); + } else { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found coefficient < 0.\n" + "Using bounds as lower - upper."); + bounds = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(lower_bound, upper_bound)); + } + + // We can attempt to deal with symbolic cases by subtracting |distance| and + // the bound nodes. If we can subtract, simplify and produce a SEConstantNode + // we can produce some information. + SEConstantNode* distance_minus_bounds = + scalar_evolution_ + .SimplifyExpression( + scalar_evolution_.CreateSubtraction(distance, bounds)) + ->AsSEConstantNode(); + if (distance_minus_bounds) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found distance - bounds as a " + "SEConstantNode with value " + + ToString(distance_minus_bounds->FoldToSingleValue())); + // If distance - bounds > 0 we prove the distance is outwith the loop + // bounds. + if (distance_minus_bounds->FoldToSingleValue() > 0) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found distance escaped the loop " + "bounds."); + return true; + } + } + + return false; +} + +const Loop* LoopDependenceAnalysis::GetLoopForSubscriptPair( + const std::pair& subscript_pair) { + // Collect all the SERecurrentNodes. + std::vector source_nodes = + std::get<0>(subscript_pair)->CollectRecurrentNodes(); + std::vector destination_nodes = + std::get<1>(subscript_pair)->CollectRecurrentNodes(); + + // Collect all the loops stored by the SERecurrentNodes. + std::unordered_set loops{}; + for (auto source_nodes_it = source_nodes.begin(); + source_nodes_it != source_nodes.end(); ++source_nodes_it) { + loops.insert((*source_nodes_it)->GetLoop()); + } + for (auto destination_nodes_it = destination_nodes.begin(); + destination_nodes_it != destination_nodes.end(); + ++destination_nodes_it) { + loops.insert((*destination_nodes_it)->GetLoop()); + } + + // If we didn't find 1 loop |subscript_pair| is a subscript over multiple or 0 + // loops. We don't handle this so return nullptr. + if (loops.size() != 1) { + PrintDebug("GetLoopForSubscriptPair found loops.size() != 1."); + return nullptr; + } + return *loops.begin(); +} + +DistanceEntry* LoopDependenceAnalysis::GetDistanceEntryForLoop( + const Loop* loop, DistanceVector* distance_vector) { + if (!loop) { + return nullptr; + } + + DistanceEntry* distance_entry = nullptr; + for (size_t loop_index = 0; loop_index < loops_.size(); ++loop_index) { + if (loop == loops_[loop_index]) { + distance_entry = &(distance_vector->GetEntries()[loop_index]); + break; + } + } + + return distance_entry; +} + +DistanceEntry* LoopDependenceAnalysis::GetDistanceEntryForSubscriptPair( + const std::pair& subscript_pair, + DistanceVector* distance_vector) { + const Loop* loop = GetLoopForSubscriptPair(subscript_pair); + + return GetDistanceEntryForLoop(loop, distance_vector); +} + +SENode* LoopDependenceAnalysis::GetTripCount(const Loop* loop) { + BasicBlock* condition_block = loop->FindConditionBlock(); + if (!condition_block) { + return nullptr; + } + Instruction* induction_instr = loop->FindConditionVariable(condition_block); + if (!induction_instr) { + return nullptr; + } + Instruction* cond_instr = loop->GetConditionInst(); + if (!cond_instr) { + return nullptr; + } + + size_t iteration_count = 0; + + // We have to check the instruction type here. If the condition instruction + // isn't a supported type we can't calculate the trip count. + if (loop->IsSupportedCondition(cond_instr->opcode())) { + if (loop->FindNumberOfIterations(induction_instr, &*condition_block->tail(), + &iteration_count)) { + return scalar_evolution_.CreateConstant( + static_cast(iteration_count)); + } + } + + return nullptr; +} + +SENode* LoopDependenceAnalysis::GetFirstTripInductionNode(const Loop* loop) { + BasicBlock* condition_block = loop->FindConditionBlock(); + if (!condition_block) { + return nullptr; + } + Instruction* induction_instr = loop->FindConditionVariable(condition_block); + if (!induction_instr) { + return nullptr; + } + int64_t induction_initial_value = 0; + if (!loop->GetInductionInitValue(induction_instr, &induction_initial_value)) { + return nullptr; + } + + SENode* induction_init_SENode = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateConstant(induction_initial_value)); + return induction_init_SENode; +} + +SENode* LoopDependenceAnalysis::GetFinalTripInductionNode( + const Loop* loop, SENode* induction_coefficient) { + SENode* first_trip_induction_node = GetFirstTripInductionNode(loop); + if (!first_trip_induction_node) { + return nullptr; + } + // Get trip_count as GetTripCount - 1 + // This is because the induction variable is not stepped on the first + // iteration of the loop + SENode* trip_count = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateSubtraction( + GetTripCount(loop), scalar_evolution_.CreateConstant(1))); + // Return first_trip_induction_node + trip_count * induction_coefficient + return scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + first_trip_induction_node, + scalar_evolution_.CreateMultiplyNode(trip_count, induction_coefficient))); +} + +std::set LoopDependenceAnalysis::CollectLoops( + const std::vector& recurrent_nodes) { + // We don't handle loops with more than one induction variable. Therefore we + // can identify the number of induction variables by collecting all of the + // loops the collected recurrent nodes belong to. + std::set loops{}; + for (auto recurrent_nodes_it = recurrent_nodes.begin(); + recurrent_nodes_it != recurrent_nodes.end(); ++recurrent_nodes_it) { + loops.insert((*recurrent_nodes_it)->GetLoop()); + } + + return loops; +} + +int64_t LoopDependenceAnalysis::CountInductionVariables(SENode* node) { + if (!node) { + return -1; + } + + std::vector recurrent_nodes = node->CollectRecurrentNodes(); + + // We don't handle loops with more than one induction variable. Therefore we + // can identify the number of induction variables by collecting all of the + // loops the collected recurrent nodes belong to. + std::set loops = CollectLoops(recurrent_nodes); + + return static_cast(loops.size()); +} + +std::set LoopDependenceAnalysis::CollectLoops( + SENode* source, SENode* destination) { + if (!source || !destination) { + return std::set{}; + } + + std::vector source_nodes = source->CollectRecurrentNodes(); + std::vector destination_nodes = + destination->CollectRecurrentNodes(); + + std::set loops = CollectLoops(source_nodes); + std::set destination_loops = CollectLoops(destination_nodes); + + loops.insert(std::begin(destination_loops), std::end(destination_loops)); + + return loops; +} + +int64_t LoopDependenceAnalysis::CountInductionVariables(SENode* source, + SENode* destination) { + if (!source || !destination) { + return -1; + } + + std::set loops = CollectLoops(source, destination); + + return static_cast(loops.size()); +} + +Instruction* LoopDependenceAnalysis::GetOperandDefinition( + const Instruction* instruction, int id) { + return context_->get_def_use_mgr()->GetDef( + instruction->GetSingleWordInOperand(id)); +} + +std::vector LoopDependenceAnalysis::GetSubscripts( + const Instruction* instruction) { + Instruction* access_chain = GetOperandDefinition(instruction, 0); + + std::vector subscripts; + + for (auto i = 1u; i < access_chain->NumInOperandWords(); ++i) { + subscripts.push_back(GetOperandDefinition(access_chain, i)); + } + + return subscripts; +} + +SENode* LoopDependenceAnalysis::GetConstantTerm(const Loop* loop, + SERecurrentNode* induction) { + SENode* offset = induction->GetOffset(); + SENode* lower_bound = GetLowerBound(loop); + if (!offset || !lower_bound) { + return nullptr; + } + SENode* constant_term = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(offset, lower_bound)); + return constant_term; +} + +bool LoopDependenceAnalysis::CheckSupportedLoops( + std::vector loops) { + for (auto loop : loops) { + if (!IsSupportedLoop(loop)) { + return false; + } + } + return true; +} + +void LoopDependenceAnalysis::MarkUnsusedDistanceEntriesAsIrrelevant( + const Instruction* source, const Instruction* destination, + DistanceVector* distance_vector) { + std::vector source_subscripts = GetSubscripts(source); + std::vector destination_subscripts = GetSubscripts(destination); + + std::set used_loops{}; + + for (Instruction* source_inst : source_subscripts) { + SENode* source_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(source_inst)); + std::vector recurrent_nodes = + source_node->CollectRecurrentNodes(); + for (SERecurrentNode* recurrent_node : recurrent_nodes) { + used_loops.insert(recurrent_node->GetLoop()); + } + } + + for (Instruction* destination_inst : destination_subscripts) { + SENode* destination_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(destination_inst)); + std::vector recurrent_nodes = + destination_node->CollectRecurrentNodes(); + for (SERecurrentNode* recurrent_node : recurrent_nodes) { + used_loops.insert(recurrent_node->GetLoop()); + } + } + + for (size_t i = 0; i < loops_.size(); ++i) { + if (used_loops.find(loops_[i]) == used_loops.end()) { + distance_vector->GetEntries()[i].dependence_information = + DistanceEntry::DependenceInformation::IRRELEVANT; + } + } +} + +bool LoopDependenceAnalysis::IsSupportedLoop(const Loop* loop) { + std::vector inductions{}; + loop->GetInductionVariables(inductions); + if (inductions.size() != 1) { + return false; + } + Instruction* induction = inductions[0]; + SENode* induction_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(induction)); + if (!induction_node->AsSERecurrentNode()) { + return false; + } + SENode* induction_step = + induction_node->AsSERecurrentNode()->GetCoefficient(); + if (!induction_step->AsSEConstantNode()) { + return false; + } + if (!(induction_step->AsSEConstantNode()->FoldToSingleValue() == 1 || + induction_step->AsSEConstantNode()->FoldToSingleValue() == -1)) { + return false; + } + return true; +} + +void LoopDependenceAnalysis::PrintDebug(std::string debug_msg) { + if (debug_stream_) { + (*debug_stream_) << debug_msg << "\n"; + } +} + +bool Constraint::operator==(const Constraint& other) const { + // A distance of |d| is equivalent to a line |x - y = -d| + if ((GetType() == ConstraintType::Distance && + other.GetType() == ConstraintType::Line) || + (GetType() == ConstraintType::Line && + other.GetType() == ConstraintType::Distance)) { + auto is_distance = AsDependenceLine() != nullptr; + + auto as_distance = + is_distance ? AsDependenceDistance() : other.AsDependenceDistance(); + auto distance = as_distance->GetDistance(); + + auto line = other.AsDependenceLine(); + + auto scalar_evolution = distance->GetParentAnalysis(); + + auto neg_distance = scalar_evolution->SimplifyExpression( + scalar_evolution->CreateNegation(distance)); + + return *scalar_evolution->CreateConstant(1) == *line->GetA() && + *scalar_evolution->CreateConstant(-1) == *line->GetB() && + *neg_distance == *line->GetC(); + } + + if (GetType() != other.GetType()) { + return false; + } + + if (AsDependenceDistance()) { + return *AsDependenceDistance()->GetDistance() == + *other.AsDependenceDistance()->GetDistance(); + } + + if (AsDependenceLine()) { + auto this_line = AsDependenceLine(); + auto other_line = other.AsDependenceLine(); + return *this_line->GetA() == *other_line->GetA() && + *this_line->GetB() == *other_line->GetB() && + *this_line->GetC() == *other_line->GetC(); + } + + if (AsDependencePoint()) { + auto this_point = AsDependencePoint(); + auto other_point = other.AsDependencePoint(); + + return *this_point->GetSource() == *other_point->GetSource() && + *this_point->GetDestination() == *other_point->GetDestination(); + } + + return true; +} + +bool Constraint::operator!=(const Constraint& other) const { + return !(*this == other); +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_descriptor.cpp b/3rdparty/spirv-tools/source/opt/loop_descriptor.cpp index 0a4ed9250..efc56bdba 100644 --- a/3rdparty/spirv-tools/source/opt/loop_descriptor.cpp +++ b/3rdparty/spirv-tools/source/opt/loop_descriptor.cpp @@ -12,41 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/loop_descriptor.h" +#include "source/opt/loop_descriptor.h" + #include #include +#include +#include #include #include #include -#include "constants.h" -#include "opt/cfg.h" -#include "opt/dominator_tree.h" -#include "opt/ir_builder.h" -#include "opt/ir_context.h" -#include "opt/iterator.h" -#include "opt/make_unique.h" -#include "opt/tree_iterator.h" +#include "source/opt/cfg.h" +#include "source/opt/constants.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "source/opt/tree_iterator.h" +#include "source/util/make_unique.h" namespace spvtools { -namespace ir { +namespace opt { // Takes in a phi instruction |induction| and the loop |header| and returns the // step operation of the loop. -ir::Instruction* Loop::GetInductionStepOperation( - const ir::Instruction* induction) const { +Instruction* Loop::GetInductionStepOperation( + const Instruction* induction) const { // Induction must be a phi instruction. assert(induction->opcode() == SpvOpPhi); - ir::Instruction* step = nullptr; + Instruction* step = nullptr; - opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); // Traverse the incoming operands of the phi instruction. for (uint32_t operand_id = 1; operand_id < induction->NumInOperands(); operand_id += 2) { // Incoming edge. - ir::BasicBlock* incoming_block = + BasicBlock* incoming_block = context_->cfg()->block(induction->GetSingleWordInOperand(operand_id)); // Check if the block is dominated by header, and thus coming from within @@ -142,17 +145,36 @@ int64_t Loop::GetResidualConditionValue(SpvOp condition, int64_t initial_value, return remainder; } +Instruction* Loop::GetConditionInst() const { + BasicBlock* condition_block = FindConditionBlock(); + if (!condition_block) { + return nullptr; + } + Instruction* branch_conditional = &*condition_block->tail(); + if (!branch_conditional || + branch_conditional->opcode() != SpvOpBranchConditional) { + return nullptr; + } + Instruction* condition_inst = context_->get_def_use_mgr()->GetDef( + branch_conditional->GetSingleWordInOperand(0)); + if (IsSupportedCondition(condition_inst->opcode())) { + return condition_inst; + } + + return nullptr; +} + // Extract the initial value from the |induction| OpPhi instruction and store it // in |value|. If the function couldn't find the initial value of |induction| // return false. -bool Loop::GetInductionInitValue(const ir::Instruction* induction, +bool Loop::GetInductionInitValue(const Instruction* induction, int64_t* value) const { - ir::Instruction* constant_instruction = nullptr; - opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + Instruction* constant_instruction = nullptr; + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); for (uint32_t operand_id = 0; operand_id < induction->NumInOperands(); operand_id += 2) { - ir::BasicBlock* bb = context_->cfg()->block( + BasicBlock* bb = context_->cfg()->block( induction->GetSingleWordInOperand(operand_id + 1)); if (!IsInsideLoop(bb)) { @@ -163,13 +185,13 @@ bool Loop::GetInductionInitValue(const ir::Instruction* induction, if (!constant_instruction) return false; - const opt::analysis::Constant* constant = + const analysis::Constant* constant = context_->get_constant_mgr()->FindDeclaredConstant( constant_instruction->result_id()); if (!constant) return false; if (value) { - const opt::analysis::Integer* type = + const analysis::Integer* type = constant->AsIntConstant()->type()->AsInteger(); if (type->IsSigned()) { @@ -182,7 +204,7 @@ bool Loop::GetInductionInitValue(const ir::Instruction* induction, return true; } -Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis, +Loop::Loop(IRContext* context, DominatorAnalysis* dom_analysis, BasicBlock* header, BasicBlock* continue_target, BasicBlock* merge_target) : context_(context), @@ -195,19 +217,20 @@ Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis, assert(context); assert(dom_analysis); loop_preheader_ = FindLoopPreheader(dom_analysis); + loop_latch_ = FindLatchBlock(); } -BasicBlock* Loop::FindLoopPreheader(opt::DominatorAnalysis* dom_analysis) { +BasicBlock* Loop::FindLoopPreheader(DominatorAnalysis* dom_analysis) { CFG* cfg = context_->cfg(); - opt::DominatorTree& dom_tree = dom_analysis->GetDomTree(); - opt::DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_); + DominatorTree& dom_tree = dom_analysis->GetDomTree(); + DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_); // The loop predecessor. BasicBlock* loop_pred = nullptr; auto header_pred = cfg->preds(loop_header_->id()); for (uint32_t p_id : header_pred) { - opt::DominatorTreeNode* node = dom_tree.GetTreeNode(p_id); + DominatorTreeNode* node = dom_tree.GetTreeNode(p_id); if (node && !dom_tree.Dominates(header_node, node)) { // The predecessor is not part of the loop, so potential loop preheader. if (loop_pred && node->bb_ != loop_pred) { @@ -244,8 +267,8 @@ bool Loop::IsInsideLoop(Instruction* inst) const { bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) { assert(bb->GetParent() && "The basic block does not belong to a function"); - opt::DominatorAnalysis* dom_analysis = - context_->GetDominatorAnalysis(bb->GetParent(), *context_->cfg()); + DominatorAnalysis* dom_analysis = + context_->GetDominatorAnalysis(bb->GetParent()); if (dom_analysis->IsReachable(bb) && !dom_analysis->Dominates(GetHeaderBlock(), bb)) return false; @@ -261,6 +284,11 @@ BasicBlock* Loop::GetOrCreatePreHeaderBlock() { return loop_preheader_; } +void Loop::SetContinueBlock(BasicBlock* continue_block) { + assert(IsInsideLoop(continue_block)); + loop_continue_ = continue_block; +} + void Loop::SetLatchBlock(BasicBlock* latch) { #ifndef NDEBUG assert(latch->GetParent() && "The basic block does not belong to a function"); @@ -302,12 +330,34 @@ void Loop::SetPreHeaderBlock(BasicBlock* preheader) { loop_preheader_ = preheader; } +BasicBlock* Loop::FindLatchBlock() { + CFG* cfg = context_->cfg(); + + DominatorAnalysis* dominator_analysis = + context_->GetDominatorAnalysis(loop_header_->GetParent()); + + // Look at the predecessors of the loop header to find a predecessor block + // which is dominated by the loop continue target. There should only be one + // block which meets this criteria and this is the latch block, as per the + // SPIR-V spec. + for (uint32_t block_id : cfg->preds(loop_header_->id())) { + if (dominator_analysis->Dominates(loop_continue_->id(), block_id)) { + return cfg->block(block_id); + } + } + + assert( + false && + "Every loop should have a latch block dominated by the continue target"); + return nullptr; +} + void Loop::GetExitBlocks(std::unordered_set* exit_blocks) const { - ir::CFG* cfg = context_->cfg(); + CFG* cfg = context_->cfg(); exit_blocks->clear(); for (uint32_t bb_id : GetBlocks()) { - const spvtools::ir::BasicBlock* bb = cfg->block(bb_id); + const BasicBlock* bb = cfg->block(bb_id); bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) { if (!IsInsideLoop(succ)) { exit_blocks->insert(succ); @@ -319,13 +369,13 @@ void Loop::GetExitBlocks(std::unordered_set* exit_blocks) const { void Loop::GetMergingBlocks( std::unordered_set* merging_blocks) const { assert(GetMergeBlock() && "This loop is not structured"); - ir::CFG* cfg = context_->cfg(); + CFG* cfg = context_->cfg(); merging_blocks->clear(); - std::stack to_visit; + std::stack to_visit; to_visit.push(GetMergeBlock()); while (!to_visit.empty()) { - const ir::BasicBlock* bb = to_visit.top(); + const BasicBlock* bb = to_visit.top(); to_visit.pop(); merging_blocks->insert(bb->id()); for (uint32_t pred_id : cfg->preds(bb->id())) { @@ -339,7 +389,7 @@ void Loop::GetMergingBlocks( namespace { static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) { - for (ir::Instruction& inst : *bb) { + for (Instruction& inst : *bb) { if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst)) return false; } @@ -350,7 +400,7 @@ static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) { } // namespace bool Loop::IsSafeToClone() const { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); for (uint32_t bb_id : GetBlocks()) { BasicBlock* bb = cfg.block(bb_id); @@ -374,14 +424,14 @@ bool Loop::IsSafeToClone() const { } bool Loop::IsLCSSA() const { - ir::CFG* cfg = context_->cfg(); - opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + CFG* cfg = context_->cfg(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); std::unordered_set exit_blocks; GetExitBlocks(&exit_blocks); // Declare ir_context so we can capture context_ in the below lambda - ir::IRContext* ir_context = context_; + IRContext* ir_context = context_; for (uint32_t bb_id : GetBlocks()) { for (Instruction& insn : *cfg->block(bb_id)) { @@ -390,7 +440,7 @@ bool Loop::IsLCSSA() const { // - In an exit block and in a phi instruction. if (!def_use_mgr->WhileEachUser( &insn, - [&exit_blocks, ir_context, this](ir::Instruction* use) -> bool { + [&exit_blocks, ir_context, this](Instruction* use) -> bool { BasicBlock* parent = ir_context->get_instr_block(use); assert(parent && "Invalid analysis"); if (IsInsideLoop(parent)) return true; @@ -409,7 +459,7 @@ bool Loop::ShouldHoistInstruction(IRContext* context, Instruction* inst) { } bool Loop::AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst) { - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); bool all_outside_loop = true; const std::function operand_outside_loop = @@ -425,9 +475,9 @@ bool Loop::AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst) { } void Loop::ComputeLoopStructuredOrder( - std::vector* ordered_loop_blocks, bool include_pre_header, + std::vector* ordered_loop_blocks, bool include_pre_header, bool include_merge) const { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); // Reserve the memory: all blocks in the loop + extra if needed. ordered_loop_blocks->reserve(GetBlocks().size() + include_pre_header + @@ -443,26 +493,23 @@ void Loop::ComputeLoopStructuredOrder( ordered_loop_blocks->push_back(loop_merge_); } -LoopDescriptor::LoopDescriptor(const Function* f) +LoopDescriptor::LoopDescriptor(IRContext* context, const Function* f) : loops_(), dummy_top_loop_(nullptr) { - PopulateList(f); + PopulateList(context, f); } LoopDescriptor::~LoopDescriptor() { ClearLoops(); } -void LoopDescriptor::PopulateList(const Function* f) { - IRContext* context = f->GetParent()->context(); - - opt::DominatorAnalysis* dom_analysis = - context->GetDominatorAnalysis(f, *context->cfg()); +void LoopDescriptor::PopulateList(IRContext* context, const Function* f) { + DominatorAnalysis* dom_analysis = context->GetDominatorAnalysis(f); ClearLoops(); // Post-order traversal of the dominator tree to find all the OpLoopMerge // instructions. - opt::DominatorTree& dom_tree = dom_analysis->GetDomTree(); - for (opt::DominatorTreeNode& node : - ir::make_range(dom_tree.post_begin(), dom_tree.post_end())) { + DominatorTree& dom_tree = dom_analysis->GetDomTree(); + for (DominatorTreeNode& node : + make_range(dom_tree.post_begin(), dom_tree.post_end())) { Instruction* merge_inst = node.bb_->GetLoopMergeInst(); if (merge_inst) { bool all_backedge_unreachable = true; @@ -516,8 +563,8 @@ void LoopDescriptor::PopulateList(const Function* f) { current_loop->AddNestedLoop(previous_loop); } - opt::DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb); - for (opt::DominatorTreeNode& loop_node : + DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb); + for (DominatorTreeNode& loop_node : make_range(node.df_begin(), node.df_end())) { // Check if we are in the loop. if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue; @@ -532,17 +579,55 @@ void LoopDescriptor::PopulateList(const Function* f) { } } -ir::BasicBlock* Loop::FindConditionBlock() const { - const ir::Function& function = *loop_merge_->GetParent(); - ir::BasicBlock* condition_block = nullptr; +std::vector LoopDescriptor::GetLoopsInBinaryLayoutOrder() { + std::vector ids{}; - const opt::DominatorAnalysis* dom_analysis = - context_->GetDominatorAnalysis(&function, *context_->cfg()); - ir::BasicBlock* bb = dom_analysis->ImmediateDominator(loop_merge_); + for (size_t i = 0; i < NumLoops(); ++i) { + ids.push_back(GetLoopByIndex(i).GetHeaderBlock()->id()); + } + + std::vector loops{}; + if (!ids.empty()) { + auto function = GetLoopByIndex(0).GetHeaderBlock()->GetParent(); + for (const auto& block : *function) { + auto block_id = block.id(); + + auto element = std::find(std::begin(ids), std::end(ids), block_id); + if (element != std::end(ids)) { + loops.push_back(&GetLoopByIndex(element - std::begin(ids))); + } + } + } + + return loops; +} + +BasicBlock* Loop::FindConditionBlock() const { + if (!loop_merge_) { + return nullptr; + } + BasicBlock* condition_block = nullptr; + + uint32_t in_loop_pred = 0; + for (uint32_t p : context_->cfg()->preds(loop_merge_->id())) { + if (IsInsideLoop(p)) { + if (in_loop_pred) { + // 2 in-loop predecessors. + return nullptr; + } + in_loop_pred = p; + } + } + if (!in_loop_pred) { + // Merge block is unreachable. + return nullptr; + } + + BasicBlock* bb = context_->cfg()->block(in_loop_pred); if (!bb) return nullptr; - const ir::Instruction& branch = *bb->ctail(); + const Instruction& branch = *bb->ctail(); // Make sure the branch is a conditional branch. if (branch.opcode() != SpvOpBranchConditional) return nullptr; @@ -556,35 +641,39 @@ ir::BasicBlock* Loop::FindConditionBlock() const { return condition_block; } -bool Loop::FindNumberOfIterations(const ir::Instruction* induction, - const ir::Instruction* branch_inst, +bool Loop::FindNumberOfIterations(const Instruction* induction, + const Instruction* branch_inst, size_t* iterations_out, int64_t* step_value_out, int64_t* init_value_out) const { // From the branch instruction find the branch condition. - opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); // Condition instruction from the OpConditionalBranch. - ir::Instruction* condition = + Instruction* condition = def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0)); assert(IsSupportedCondition(condition->opcode())); // Get the constant manager from the ir context. - opt::analysis::ConstantManager* const_manager = context_->get_constant_mgr(); + analysis::ConstantManager* const_manager = context_->get_constant_mgr(); // Find the constant value used by the condition variable. Exit out if it // isn't a constant int. - const opt::analysis::Constant* upper_bound = + const analysis::Constant* upper_bound = const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3)); if (!upper_bound) return false; // Must be integer because of the opcode on the condition. int64_t condition_value = 0; - const opt::analysis::Integer* type = + const analysis::Integer* type = upper_bound->AsIntConstant()->type()->AsInteger(); + if (type->width() > 32) { + return false; + } + if (type->IsSigned()) { condition_value = upper_bound->AsIntConstant()->GetS32BitValue(); } else { @@ -592,18 +681,18 @@ bool Loop::FindNumberOfIterations(const ir::Instruction* induction, } // Find the instruction which is stepping through the loop. - ir::Instruction* step_inst = GetInductionStepOperation(induction); + Instruction* step_inst = GetInductionStepOperation(induction); if (!step_inst) return false; // Find the constant value used by the condition variable. - const opt::analysis::Constant* step_constant = + const analysis::Constant* step_constant = const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3)); if (!step_constant) return false; // Must be integer because of the opcode on the condition. int64_t step_value = 0; - const opt::analysis::Integer* step_type = + const analysis::Integer* step_type = step_constant->AsIntConstant()->type()->AsInteger(); if (step_type->IsSigned()) { @@ -745,34 +834,34 @@ int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value, // Returns the list of induction variables within the loop. void Loop::GetInductionVariables( - std::vector& induction_variables) const { - for (ir::Instruction& inst : *loop_header_) { + std::vector& induction_variables) const { + for (Instruction& inst : *loop_header_) { if (inst.opcode() == SpvOp::SpvOpPhi) { induction_variables.push_back(&inst); } } } -ir::Instruction* Loop::FindConditionVariable( - const ir::BasicBlock* condition_block) const { +Instruction* Loop::FindConditionVariable( + const BasicBlock* condition_block) const { // Find the branch instruction. - const ir::Instruction& branch_inst = *condition_block->ctail(); + const Instruction& branch_inst = *condition_block->ctail(); - ir::Instruction* induction = nullptr; + Instruction* induction = nullptr; // Verify that the branch instruction is a conditional branch. if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) { // From the branch instruction find the branch condition. - opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); // Find the instruction representing the condition used in the conditional // branch. - ir::Instruction* condition = + Instruction* condition = def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0)); // Ensure that the condition is a less than operation. if (condition && IsSupportedCondition(condition->opcode())) { // The left hand side operand of the operation. - ir::Instruction* variable_inst = + Instruction* variable_inst = def_use_manager->GetDef(condition->GetSingleWordOperand(2)); // Make sure the variable instruction used is a phi. @@ -792,18 +881,18 @@ ir::Instruction* Loop::FindConditionVariable( uint32_t operand_label_2 = 3; // Make sure one of them is the preheader. - if (variable_inst->GetSingleWordInOperand(operand_label_1) != - loop_preheader_->id() && - variable_inst->GetSingleWordInOperand(operand_label_2) != - loop_preheader_->id()) { + if (!IsInsideLoop( + variable_inst->GetSingleWordInOperand(operand_label_1)) && + !IsInsideLoop( + variable_inst->GetSingleWordInOperand(operand_label_2))) { return nullptr; } // And make sure that the other is the latch block. if (variable_inst->GetSingleWordInOperand(operand_label_1) != - loop_continue_->id() && + loop_latch_->id() && variable_inst->GetSingleWordInOperand(operand_label_2) != - loop_continue_->id()) { + loop_latch_->id()) { return nullptr; } } else { @@ -819,11 +908,24 @@ ir::Instruction* Loop::FindConditionVariable( return induction; } +bool LoopDescriptor::CreatePreHeaderBlocksIfMissing() { + auto modified = false; + + for (auto& loop : *this) { + if (!loop.GetPreHeaderBlock()) { + modified = true; + loop.GetOrCreatePreHeaderBlock(); + } + } + + return modified; +} + // Add and remove loops which have been marked for addition and removal to // maintain the state of the loop descriptor class. void LoopDescriptor::PostModificationCleanup() { LoopContainerType loops_to_remove_; - for (ir::Loop* loop : loops_) { + for (Loop* loop : loops_) { if (loop->IsMarkedForRemoval()) { loops_to_remove_.push_back(loop); if (loop->HasParent()) { @@ -832,13 +934,13 @@ void LoopDescriptor::PostModificationCleanup() { } } - for (ir::Loop* loop : loops_to_remove_) { + for (Loop* loop : loops_to_remove_) { loops_.erase(std::find(loops_.begin(), loops_.end(), loop)); } for (auto& pair : loops_to_add_) { - ir::Loop* parent = pair.first; - ir::Loop* loop = pair.second; + Loop* parent = pair.first; + Loop* loop = pair.second; if (parent) { loop->SetParent(nullptr); @@ -863,12 +965,12 @@ void LoopDescriptor::ClearLoops() { } // Adds a new loop nest to the descriptor set. -ir::Loop* LoopDescriptor::AddLoopNest(std::unique_ptr new_loop) { - ir::Loop* loop = new_loop.release(); +Loop* LoopDescriptor::AddLoopNest(std::unique_ptr new_loop) { + Loop* loop = new_loop.release(); if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop); // Iterate from inner to outer most loop, adding basic block to loop mapping // as we go. - for (ir::Loop& current_loop : + for (Loop& current_loop : make_range(iterator::begin(loop), iterator::end(nullptr))) { loops_.push_back(¤t_loop); for (uint32_t bb_id : current_loop.GetBlocks()) @@ -878,18 +980,18 @@ ir::Loop* LoopDescriptor::AddLoopNest(std::unique_ptr new_loop) { return loop; } -void LoopDescriptor::RemoveLoop(ir::Loop* loop) { - ir::Loop* parent = loop->GetParent() ? loop->GetParent() : &dummy_top_loop_; +void LoopDescriptor::RemoveLoop(Loop* loop) { + Loop* parent = loop->GetParent() ? loop->GetParent() : &dummy_top_loop_; parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(), parent->nested_loops_.end(), loop)); std::for_each( loop->nested_loops_.begin(), loop->nested_loops_.end(), - [loop](ir::Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); }); + [loop](Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); }); parent->nested_loops_.insert(parent->nested_loops_.end(), loop->nested_loops_.begin(), loop->nested_loops_.end()); for (uint32_t bb_id : loop->GetBlocks()) { - ir::Loop* l = FindLoopForBasicBlock(bb_id); + Loop* l = FindLoopForBasicBlock(bb_id); if (l == loop) { SetBasicBlockToLoop(bb_id, l->GetParent()); } else { @@ -904,5 +1006,5 @@ void LoopDescriptor::RemoveLoop(ir::Loop* loop) { loops_.erase(it); } -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_descriptor.h b/3rdparty/spirv-tools/source/opt/loop_descriptor.h index 05acce207..45a175a0c 100644 --- a/3rdparty/spirv-tools/source/opt/loop_descriptor.h +++ b/3rdparty/spirv-tools/source/opt/loop_descriptor.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_ -#define LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_ +#ifndef SOURCE_OPT_LOOP_DESCRIPTOR_H_ +#define SOURCE_OPT_LOOP_DESCRIPTOR_H_ #include #include @@ -21,18 +21,17 @@ #include #include #include +#include #include -#include "opt/basic_block.h" -#include "opt/module.h" -#include "opt/tree_iterator.h" +#include "source/opt/basic_block.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/module.h" +#include "source/opt/tree_iterator.h" namespace spvtools { namespace opt { -class DominatorAnalysis; -struct DominatorTreeNode; -} // namespace opt -namespace ir { + class IRContext; class CFG; class LoopDescriptor; @@ -53,14 +52,13 @@ class Loop { loop_continue_(nullptr), loop_merge_(nullptr), loop_preheader_(nullptr), + loop_latch_(nullptr), parent_(nullptr), loop_is_marked_for_removal_(false) {} - Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header, + Loop(IRContext* context, DominatorAnalysis* analysis, BasicBlock* header, BasicBlock* continue_target, BasicBlock* merge_target); - ~Loop() {} - // Iterators over the immediate sub-loops. inline iterator begin() { return nested_loops_.begin(); } inline iterator end() { return nested_loops_.end(); } @@ -80,21 +78,31 @@ class Loop { inline void UpdateLoopMergeInst() { assert(GetHeaderBlock()->GetLoopMergeInst() && "The loop is not structured"); - ir::Instruction* merge_inst = GetHeaderBlock()->GetLoopMergeInst(); + Instruction* merge_inst = GetHeaderBlock()->GetLoopMergeInst(); merge_inst->SetInOperand(0, {GetMergeBlock()->id()}); } + // Returns the continue target basic block. This is the block designated as + // the continue target by the OpLoopMerge instruction. + inline BasicBlock* GetContinueBlock() { return loop_continue_; } + inline const BasicBlock* GetContinueBlock() const { return loop_continue_; } + // Returns the latch basic block (basic block that holds the back-edge). // These functions return nullptr if the loop is not structured (i.e. if it // has more than one backedge). - inline BasicBlock* GetLatchBlock() { return loop_continue_; } - inline const BasicBlock* GetLatchBlock() const { return loop_continue_; } + inline BasicBlock* GetLatchBlock() { return loop_latch_; } + inline const BasicBlock* GetLatchBlock() const { return loop_latch_; } + // Sets |latch| as the loop unique block branching back to the header. // A latch block must have the following properties: // - |latch| must be in the loop; // - must be the only block branching back to the header block. void SetLatchBlock(BasicBlock* latch); + // Sets |continue_block| as the continue block of the loop. This should be the + // continue target of the OpLoopMerge and should dominate the latch block. + void SetContinueBlock(BasicBlock* continue_block); + // Returns the basic block which marks the end of the loop. // These functions return nullptr if the loop is not structured. inline BasicBlock* GetMergeBlock() { return loop_merge_; } @@ -155,6 +163,7 @@ class Loop { inline size_t NumImmediateChildren() const { return nested_loops_.size(); } + inline bool HasChildren() const { return !nested_loops_.empty(); } // Adds |nested| as a nested loop of this loop. Automatically register |this| // as the parent of |nested|. inline void AddNestedLoop(Loop* nested) { @@ -224,21 +233,20 @@ class Loop { } // Returns the list of induction variables within the loop. - void GetInductionVariables(std::vector& inductions) const; + void GetInductionVariables(std::vector& inductions) const; // This function uses the |condition| to find the induction variable which is // used by the loop condition within the loop. This only works if the loop is // bound by a single condition and single induction variable. - ir::Instruction* FindConditionVariable(const ir::BasicBlock* condition) const; + Instruction* FindConditionVariable(const BasicBlock* condition) const; // Returns the number of iterations within a loop when given the |induction| // variable and the loop |condition| check. It stores the found number of // iterations in the output parameter |iterations| and optionally, the step // value in |step_value| and the initial value of the induction variable in // |init_value|. - bool FindNumberOfIterations(const ir::Instruction* induction, - const ir::Instruction* condition, - size_t* iterations, + bool FindNumberOfIterations(const Instruction* induction, + const Instruction* condition, size_t* iterations, int64_t* step_amount = nullptr, int64_t* init_value = nullptr) const; @@ -254,7 +262,7 @@ class Loop { // Finds the conditional block with a branch to the merge and continue blocks // within the loop body. - ir::BasicBlock* FindConditionBlock() const; + BasicBlock* FindConditionBlock() const; // Remove the child loop form this loop. inline void RemoveChildLoop(Loop* loop) { @@ -298,13 +306,12 @@ class Loop { // Extract the initial value from the |induction| variable and store it in // |value|. If the function couldn't find the initial value of |induction| // return false. - bool GetInductionInitValue(const ir::Instruction* induction, + bool GetInductionInitValue(const Instruction* induction, int64_t* value) const; // Takes in a phi instruction |induction| and the loop |header| and returns // the step operation of the loop. - ir::Instruction* GetInductionStepOperation( - const ir::Instruction* induction) const; + Instruction* GetInductionStepOperation(const Instruction* induction) const; // Returns true if we can deduce the number of loop iterations in the step // operation |step|. IsSupportedCondition must also be true for the condition @@ -321,9 +328,9 @@ class Loop { // pre-header block will also be included at the beginning of the list if it // exist. If |include_merge| is true, the merge block will also be included at // the end of the list if it exist. - void ComputeLoopStructuredOrder( - std::vector* ordered_loop_blocks, - bool include_pre_header = false, bool include_merge = false) const; + void ComputeLoopStructuredOrder(std::vector* ordered_loop_blocks, + bool include_pre_header = false, + bool include_merge = false) const; // Given the loop |condition|, |initial_value|, |step_value|, the trip count // |number_of_iterations|, and the |unroll_factor| requested, get the new @@ -334,6 +341,19 @@ class Loop { size_t number_of_iterations, size_t unroll_factor); + // Returns the condition instruction for entry into the loop + // Returns nullptr if it can't be found. + Instruction* GetConditionInst() const; + + // Returns the context associated this loop. + IRContext* GetContext() const { return context_; } + + // Looks at all the blocks with a branch to the header block to find one + // which is also dominated by the loop continue block. This block is the latch + // block. The specification mandates that this block should exist, therefore + // this function will assert if it is not found. + BasicBlock* FindLatchBlock(); + private: IRContext* context_; // The block which marks the start of the loop. @@ -348,6 +368,9 @@ class Loop { // The block immediately before the loop header. BasicBlock* loop_preheader_; + // The block containing the backedge to the loop header. + BasicBlock* loop_latch_; + // A parent of a loop is the loop which contains it as a nested child loop. Loop* parent_; @@ -364,11 +387,11 @@ class Loop { bool IsBasicBlockInLoopSlow(const BasicBlock* bb); // Returns the loop preheader if it exists, returns nullptr otherwise. - BasicBlock* FindLoopPreheader(opt::DominatorAnalysis* dom_analysis); + BasicBlock* FindLoopPreheader(DominatorAnalysis* dom_analysis); - // Sets |latch| as the loop unique continue block. No checks are performed + // Sets |latch| as the loop unique latch block. No checks are performed // here. - inline void SetLatchBlockImpl(BasicBlock* latch) { loop_continue_ = latch; } + inline void SetLatchBlockImpl(BasicBlock* latch) { loop_latch_ = latch; } // Sets |merge| as the loop merge block. No checks are performed here. inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; } @@ -395,11 +418,14 @@ class Loop { class LoopDescriptor { public: // Iterator interface (depth first postorder traversal). - using iterator = opt::PostOrderTreeDFIterator; - using const_iterator = opt::PostOrderTreeDFIterator; + using iterator = PostOrderTreeDFIterator; + using const_iterator = PostOrderTreeDFIterator; + + using pre_iterator = TreeDFIterator; + using const_pre_iterator = TreeDFIterator; // Creates a loop object for all loops found in |f|. - explicit LoopDescriptor(const Function* f); + LoopDescriptor(IRContext* context, const Function* f); // Disable copy constructor, to avoid double-free on destruction. LoopDescriptor(const LoopDescriptor&) = delete; @@ -428,6 +454,10 @@ class LoopDescriptor { return *loops_[index]; } + // Returns the loops in |this| in the order their headers appear in the + // binary. + std::vector GetLoopsInBinaryLayoutOrder(); + // Returns the inner most loop that contains the basic block id |block_id|. inline Loop* operator[](uint32_t block_id) const { return FindLoopForBasicBlock(block_id); @@ -451,6 +481,17 @@ class LoopDescriptor { return const_iterator::end(&dummy_top_loop_); } + // Iterators for pre-order depth first traversal of the loops. + // Inner most loops will be visited first. + inline pre_iterator pre_begin() { return ++pre_iterator(&dummy_top_loop_); } + inline pre_iterator pre_end() { return pre_iterator(); } + inline const_pre_iterator pre_begin() const { return pre_cbegin(); } + inline const_pre_iterator pre_end() const { return pre_cend(); } + inline const_pre_iterator pre_cbegin() const { + return ++const_pre_iterator(&dummy_top_loop_); + } + inline const_pre_iterator pre_cend() const { return const_pre_iterator(); } + // Returns the inner most loop that contains the basic block |bb|. inline void SetBasicBlockToLoop(uint32_t bb_id, Loop* loop) { basic_block_to_loop_[bb_id] = loop; @@ -458,10 +499,14 @@ class LoopDescriptor { // Mark the loop |loop_to_add| as needing to be added when the user calls // PostModificationCleanup. |parent| may be null. - inline void AddLoop(ir::Loop* loop_to_add, ir::Loop* parent) { + inline void AddLoop(Loop* loop_to_add, Loop* parent) { loops_to_add_.emplace_back(std::make_pair(parent, loop_to_add)); } + // Checks all loops in |this| and will create pre-headers for all loops + // that don't have one. Returns |true| if any blocks were created. + bool CreatePreHeaderBlocksIfMissing(); + // Should be called to preserve the LoopAnalysis after loops have been marked // for addition with AddLoop or MarkLoopForRemoval. void PostModificationCleanup(); @@ -473,12 +518,12 @@ class LoopDescriptor { // Adds the loop |new_loop| and all its nested loops to the descriptor set. // The object takes ownership of all the loops. - ir::Loop* AddLoopNest(std::unique_ptr new_loop); + Loop* AddLoopNest(std::unique_ptr new_loop); // Remove the loop |loop|. - void RemoveLoop(ir::Loop* loop); + void RemoveLoop(Loop* loop); - void SetAsTopLoop(ir::Loop* loop) { + void SetAsTopLoop(Loop* loop) { assert(std::find(dummy_top_loop_.begin(), dummy_top_loop_.end(), loop) == dummy_top_loop_.end() && "already registered"); @@ -495,7 +540,7 @@ class LoopDescriptor { using LoopsToAddContainerType = std::vector>; // Creates loop descriptors for the function |f|. - void PopulateList(const Function* f); + void PopulateList(IRContext* context, const Function* f); // Returns the inner most loop that contains the basic block id |block_id|. inline Loop* FindLoopForBasicBlock(uint32_t block_id) const { @@ -521,7 +566,7 @@ class LoopDescriptor { LoopsToAddContainerType loops_to_add_; }; -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_ +#endif // SOURCE_OPT_LOOP_DESCRIPTOR_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_fission.cpp b/3rdparty/spirv-tools/source/opt/loop_fission.cpp new file mode 100644 index 000000000..0052406dd --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_fission.cpp @@ -0,0 +1,512 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_fission.h" + +#include + +#include "source/opt/register_pressure.h" + +// Implement loop fission with an optional parameter to split only +// if the register pressure in a given loop meets a certain criteria. This is +// controlled via the constructors of LoopFissionPass. +// +// 1 - Build a list of loops to be split, these are top level loops (loops +// without child loops themselves) which meet the register pressure criteria, as +// determined by the ShouldSplitLoop method of LoopFissionPass. +// +// 2 - For each loop in the list, group each instruction into a set of related +// instructions by traversing each instructions users and operands recursively. +// We stop if we encounter an instruction we have seen before or an instruction +// which we don't consider relevent (i.e OpLoopMerge). We then group these +// groups into two different sets, one for the first loop and one for the +// second. +// +// 3 - We then run CanPerformSplit to check that it would be legal to split a +// loop using those two sets. We check that we haven't altered the relative +// order load/stores appear in the binary and that we aren't breaking any +// dependency between load/stores by splitting them into two loops. We also +// check that none of the OpBranch instructions are dependent on a load as we +// leave control flow structure intact and move only instructions in the body so +// we want to avoid any loads with side affects or aliasing. +// +// 4 - We then split the loop by calling SplitLoop. This function clones the +// loop and attaches it to the preheader and connects the new loops merge block +// to the current loop header block. We then use the two sets built in step 2 to +// remove instructions from each loop. If an instruction appears in the first +// set it is removed from the second loop and vice versa. +// +// 5 - If the multiple split passes flag is set we check if each of the loops +// still meet the register pressure criteria. If they do then we add them to the +// list of loops to be split (created in step one) to allow for loops to be +// split multiple times. +// + +namespace spvtools { +namespace opt { + +class LoopFissionImpl { + public: + LoopFissionImpl(IRContext* context, Loop* loop) + : context_(context), loop_(loop), load_used_in_condition_(false) {} + + // Group each instruction in the loop into sets of instructions related by + // their usedef chains. An instruction which uses another will appear in the + // same set. Then merge those sets into just two sets. Returns false if there + // was one or less sets created. + bool GroupInstructionsByUseDef(); + + // Check if the sets built by GroupInstructionsByUseDef violate any data + // dependence rules. + bool CanPerformSplit(); + + // Split the loop and return a pointer to the new loop. + Loop* SplitLoop(); + + // Checks if |inst| is safe to move. We can only move instructions which don't + // have any side effects and OpLoads and OpStores. + bool MovableInstruction(const Instruction& inst) const; + + private: + // Traverse the def use chain of |inst| and add the users and uses of |inst| + // which are in the same loop to the |returned_set|. + void TraverseUseDef(Instruction* inst, std::set* returned_set, + bool ignore_phi_users = false, bool report_loads = false); + + // We group the instructions in the block into two different groups, the + // instructions to be kept in the original loop and the ones to be cloned into + // the new loop. As the cloned loop is attached to the preheader it will be + // the first loop and the second loop will be the original. + std::set cloned_loop_instructions_; + std::set original_loop_instructions_; + + // We need a set of all the instructions to be seen so we can break any + // recursion and also so we can ignore certain instructions by preemptively + // adding them to this set. + std::set seen_instructions_; + + // A map of instructions to their relative position in the function. + std::map instruction_order_; + + IRContext* context_; + + Loop* loop_; + + // This is set to true by TraverseUseDef when traversing the instructions + // related to the loop condition and any if conditions should any of those + // instructions be a load. + bool load_used_in_condition_; +}; + +bool LoopFissionImpl::MovableInstruction(const Instruction& inst) const { + return inst.opcode() == SpvOp::SpvOpLoad || + inst.opcode() == SpvOp::SpvOpStore || + inst.opcode() == SpvOp::SpvOpSelectionMerge || + inst.opcode() == SpvOp::SpvOpPhi || inst.IsOpcodeCodeMotionSafe(); +} + +void LoopFissionImpl::TraverseUseDef(Instruction* inst, + std::set* returned_set, + bool ignore_phi_users, bool report_loads) { + assert(returned_set && "Set to be returned cannot be null."); + + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + std::set& inst_set = *returned_set; + + // We create this functor to traverse the use def chain to build the + // grouping of related instructions. The lambda captures the std::function + // to allow it to recurse. + std::function traverser_functor; + traverser_functor = [this, def_use, &inst_set, &traverser_functor, + ignore_phi_users, report_loads](Instruction* user) { + // If we've seen the instruction before or it is not inside the loop end the + // traversal. + if (!user || seen_instructions_.count(user) != 0 || + !context_->get_instr_block(user) || + !loop_->IsInsideLoop(context_->get_instr_block(user))) { + return; + } + + // Don't include labels or loop merge instructions in the instruction sets. + // Including them would mean we group instructions related only by using the + // same labels (i.e phis). We already preempt the inclusion of + // OpSelectionMerge by adding related instructions to the seen_instructions_ + // set. + if (user->opcode() == SpvOp::SpvOpLoopMerge || + user->opcode() == SpvOp::SpvOpLabel) + return; + + // If the |report_loads| flag is set, set the class field + // load_used_in_condition_ to false. This is used to check that none of the + // condition checks in the loop rely on loads. + if (user->opcode() == SpvOp::SpvOpLoad && report_loads) { + load_used_in_condition_ = true; + } + + // Add the instruction to the set of instructions already seen, this breaks + // recursion and allows us to ignore certain instructions. + seen_instructions_.insert(user); + + inst_set.insert(user); + + // Wrapper functor to traverse the operands of each instruction. + auto traverse_operand = [&traverser_functor, def_use](const uint32_t* id) { + traverser_functor(def_use->GetDef(*id)); + }; + user->ForEachInOperand(traverse_operand); + + // For the first traversal we want to ignore the users of the phi. + if (ignore_phi_users && user->opcode() == SpvOp::SpvOpPhi) return; + + // Traverse each user with this lambda. + def_use->ForEachUser(user, traverser_functor); + + // Wrapper functor for the use traversal. + auto traverse_use = [&traverser_functor](Instruction* use, uint32_t) { + traverser_functor(use); + }; + def_use->ForEachUse(user, traverse_use); + + }; + + // We start the traversal of the use def graph by invoking the above + // lambda with the |inst| parameter. + traverser_functor(inst); +} + +bool LoopFissionImpl::GroupInstructionsByUseDef() { + std::vector> sets{}; + + // We want to ignore all the instructions stemming from the loop condition + // instruction. + BasicBlock* condition_block = loop_->FindConditionBlock(); + + if (!condition_block) return false; + Instruction* condition = &*condition_block->tail(); + + // We iterate over the blocks via iterating over all the blocks in the + // function, we do this so we are iterating in the same order which the blocks + // appear in the binary. + Function& function = *loop_->GetHeaderBlock()->GetParent(); + + // Create a temporary set to ignore certain groups of instructions within the + // loop. We don't want any instructions related to control flow to be removed + // from either loop only instructions within the control flow bodies. + std::set instructions_to_ignore{}; + TraverseUseDef(condition, &instructions_to_ignore, true, true); + + // Traverse control flow instructions to ensure they are added to the + // seen_instructions_ set and will be ignored when it it called with actual + // sets. + for (BasicBlock& block : function) { + if (!loop_->IsInsideLoop(block.id())) continue; + + for (Instruction& inst : block) { + // Ignore all instructions related to control flow. + if (inst.opcode() == SpvOp::SpvOpSelectionMerge || inst.IsBranch()) { + TraverseUseDef(&inst, &instructions_to_ignore, true, true); + } + } + } + + // Traverse the instructions and generate the sets, automatically ignoring any + // instructions in instructions_to_ignore. + for (BasicBlock& block : function) { + if (!loop_->IsInsideLoop(block.id()) || + loop_->GetHeaderBlock()->id() == block.id()) + continue; + + for (Instruction& inst : block) { + // Record the order that each load/store is seen. + if (inst.opcode() == SpvOp::SpvOpLoad || + inst.opcode() == SpvOp::SpvOpStore) { + instruction_order_[&inst] = instruction_order_.size(); + } + + // Ignore instructions already seen in a traversal. + if (seen_instructions_.count(&inst) != 0) { + continue; + } + + // Build the set. + std::set inst_set{}; + TraverseUseDef(&inst, &inst_set); + if (!inst_set.empty()) sets.push_back(std::move(inst_set)); + } + } + + // If we have one or zero sets return false to indicate that due to + // insufficient instructions we couldn't split the loop into two groups and + // thus the loop can't be split any further. + if (sets.size() < 2) { + return false; + } + + // Merge the loop sets into two different sets. In CanPerformSplit we will + // validate that we don't break the relative ordering of loads/stores by doing + // this. + for (size_t index = 0; index < sets.size() / 2; ++index) { + cloned_loop_instructions_.insert(sets[index].begin(), sets[index].end()); + } + for (size_t index = sets.size() / 2; index < sets.size(); ++index) { + original_loop_instructions_.insert(sets[index].begin(), sets[index].end()); + } + + return true; +} + +bool LoopFissionImpl::CanPerformSplit() { + // Return false if any of the condition instructions in the loop depend on a + // load. + if (load_used_in_condition_) { + return false; + } + + // Build a list of all parent loops of this loop. Loop dependence analysis + // needs this structure. + std::vector loops; + Loop* parent_loop = loop_; + while (parent_loop) { + loops.push_back(parent_loop); + parent_loop = parent_loop->GetParent(); + } + + LoopDependenceAnalysis analysis{context_, loops}; + + // A list of all the stores in the cloned loop. + std::vector set_one_stores{}; + + // A list of all the loads in the cloned loop. + std::vector set_one_loads{}; + + // Populate the above lists. + for (Instruction* inst : cloned_loop_instructions_) { + if (inst->opcode() == SpvOp::SpvOpStore) { + set_one_stores.push_back(inst); + } else if (inst->opcode() == SpvOp::SpvOpLoad) { + set_one_loads.push_back(inst); + } + + // If we find any instruction which we can't move (such as a barrier), + // return false. + if (!MovableInstruction(*inst)) return false; + } + + // We need to calculate the depth of the loop to create the loop dependency + // distance vectors. + const size_t loop_depth = loop_->GetDepth(); + + // Check the dependencies between loads in the cloned loop and stores in the + // original and vice versa. + for (Instruction* inst : original_loop_instructions_) { + // If we find any instruction which we can't move (such as a barrier), + // return false. + if (!MovableInstruction(*inst)) return false; + + // Look at the dependency between the loads in the original and stores in + // the cloned loops. + if (inst->opcode() == SpvOp::SpvOpLoad) { + for (Instruction* store : set_one_stores) { + DistanceVector vec{loop_depth}; + + // If the store actually should appear after the load, return false. + // This means the store has been placed in the wrong grouping. + if (instruction_order_[store] > instruction_order_[inst]) { + return false; + } + // If not independent check the distance vector. + if (!analysis.GetDependence(store, inst, &vec)) { + for (DistanceEntry& entry : vec.GetEntries()) { + // A distance greater than zero means that the store in the cloned + // loop has a dependency on the load in the original loop. + if (entry.distance > 0) return false; + } + } + } + } else if (inst->opcode() == SpvOp::SpvOpStore) { + for (Instruction* load : set_one_loads) { + DistanceVector vec{loop_depth}; + + // If the load actually should appear after the store, return false. + if (instruction_order_[load] > instruction_order_[inst]) { + return false; + } + + // If not independent check the distance vector. + if (!analysis.GetDependence(inst, load, &vec)) { + for (DistanceEntry& entry : vec.GetEntries()) { + // A distance less than zero means the load in the cloned loop is + // dependent on the store instruction in the original loop. + if (entry.distance < 0) return false; + } + } + } + } + } + return true; +} + +Loop* LoopFissionImpl::SplitLoop() { + // Clone the loop. + LoopUtils util{context_, loop_}; + LoopUtils::LoopCloningResult clone_results; + Loop* cloned_loop = util.CloneAndAttachLoopToHeader(&clone_results); + + // Update the OpLoopMerge in the cloned loop. + cloned_loop->UpdateLoopMergeInst(); + + // Add the loop_ to the module. + Function::iterator it = + util.GetFunction()->FindBlock(loop_->GetOrCreatePreHeaderBlock()->id()); + util.GetFunction()->AddBasicBlocks(clone_results.cloned_bb_.begin(), + clone_results.cloned_bb_.end(), ++it); + loop_->SetPreHeaderBlock(cloned_loop->GetMergeBlock()); + + std::vector instructions_to_kill{}; + + // Kill all the instructions which should appear in the cloned loop but not in + // the original loop. + for (uint32_t id : loop_->GetBlocks()) { + BasicBlock* block = context_->cfg()->block(id); + + for (Instruction& inst : *block) { + // If the instruction appears in the cloned loop instruction group, kill + // it. + if (cloned_loop_instructions_.count(&inst) == 1 && + original_loop_instructions_.count(&inst) == 0) { + instructions_to_kill.push_back(&inst); + if (inst.opcode() == SpvOp::SpvOpPhi) { + context_->ReplaceAllUsesWith( + inst.result_id(), clone_results.value_map_[inst.result_id()]); + } + } + } + } + + // Kill all instructions which should appear in the original loop and not in + // the cloned loop. + for (uint32_t id : cloned_loop->GetBlocks()) { + BasicBlock* block = context_->cfg()->block(id); + for (Instruction& inst : *block) { + Instruction* old_inst = clone_results.ptr_map_[&inst]; + // If the instruction belongs to the original loop instruction group, kill + // it. + if (cloned_loop_instructions_.count(old_inst) == 0 && + original_loop_instructions_.count(old_inst) == 1) { + instructions_to_kill.push_back(&inst); + } + } + } + + for (Instruction* i : instructions_to_kill) { + context_->KillInst(i); + } + + return cloned_loop; +} + +LoopFissionPass::LoopFissionPass(const size_t register_threshold_to_split, + bool split_multiple_times) + : split_multiple_times_(split_multiple_times) { + // Split if the number of registers in the loop exceeds + // |register_threshold_to_split|. + split_criteria_ = + [register_threshold_to_split]( + const RegisterLiveness::RegionRegisterLiveness& liveness) { + return liveness.used_registers_ > register_threshold_to_split; + }; +} + +LoopFissionPass::LoopFissionPass() : split_multiple_times_(false) { + // Split by default. + split_criteria_ = [](const RegisterLiveness::RegionRegisterLiveness&) { + return true; + }; +} + +bool LoopFissionPass::ShouldSplitLoop(const Loop& loop, IRContext* c) { + LivenessAnalysis* analysis = c->GetLivenessAnalysis(); + + RegisterLiveness::RegionRegisterLiveness liveness{}; + + Function* function = loop.GetHeaderBlock()->GetParent(); + analysis->Get(function)->ComputeLoopRegisterPressure(loop, &liveness); + + return split_criteria_(liveness); +} + +Pass::Status LoopFissionPass::Process() { + bool changed = false; + + for (Function& f : *context()->module()) { + // We collect all the inner most loops in the function and run the loop + // splitting util on each. The reason we do this is to allow us to iterate + // over each, as creating new loops will invalidate the the loop iterator. + std::vector inner_most_loops{}; + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); + for (Loop& loop : loop_descriptor) { + if (!loop.HasChildren() && ShouldSplitLoop(loop, context())) { + inner_most_loops.push_back(&loop); + } + } + + // List of new loops which meet the criteria to be split again. + std::vector new_loops_to_split{}; + + while (!inner_most_loops.empty()) { + for (Loop* loop : inner_most_loops) { + LoopFissionImpl impl{context(), loop}; + + // Group the instructions in the loop into two different sets of related + // instructions. If we can't group the instructions into the two sets + // then we can't split the loop any further. + if (!impl.GroupInstructionsByUseDef()) { + continue; + } + + if (impl.CanPerformSplit()) { + Loop* second_loop = impl.SplitLoop(); + changed = true; + context()->InvalidateAnalysesExceptFor( + IRContext::kAnalysisLoopAnalysis); + + // If the newly created loop meets the criteria to be split, split it + // again. + if (ShouldSplitLoop(*second_loop, context())) + new_loops_to_split.push_back(second_loop); + + // If the original loop (now split) still meets the criteria to be + // split, split it again. + if (ShouldSplitLoop(*loop, context())) + new_loops_to_split.push_back(loop); + } + } + + // If the split multiple times flag has been set add the new loops which + // meet the splitting criteria into the list of loops to be split on the + // next iteration. + if (split_multiple_times_) { + inner_most_loops = std::move(new_loops_to_split); + } else { + break; + } + } + } + + return changed ? Pass::Status::SuccessWithChange + : Pass::Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_fission.h b/3rdparty/spirv-tools/source/opt/loop_fission.h new file mode 100644 index 000000000..e7a59c185 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_fission.h @@ -0,0 +1,78 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_FISSION_H_ +#define SOURCE_OPT_LOOP_FISSION_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/cfg.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_utils.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" +#include "source/opt/tree_iterator.h" + +namespace spvtools { +namespace opt { + +class LoopFissionPass : public Pass { + public: + // Fuction used to determine if a given loop should be split. Takes register + // pressure region for that loop as a parameter and returns true if the loop + // should be split. + using FissionCriteriaFunction = + std::function; + + // Pass built with this constructor will split all loops regardless of + // register pressure. Will not split loops more than once. + LoopFissionPass(); + + // Split the loop if the number of registers used in the loop exceeds + // |register_threshold_to_split|. |split_multiple_times| flag determines + // whether or not the pass should split loops after already splitting them + // once. + LoopFissionPass(size_t register_threshold_to_split, + bool split_multiple_times = true); + + // Split loops whose register pressure meets the criteria of |functor|. + LoopFissionPass(FissionCriteriaFunction functor, + bool split_multiple_times = true) + : split_criteria_(functor), split_multiple_times_(split_multiple_times) {} + + const char* name() const override { return "loop-fission"; } + + Pass::Status Process() override; + + // Checks if |loop| meets the register pressure criteria to be split. + bool ShouldSplitLoop(const Loop& loop, IRContext* context); + + private: + // Functor to run in ShouldSplitLoop to determine if the register pressure + // criteria is met for splitting the loop. + FissionCriteriaFunction split_criteria_; + + // Flag designating whether or not we should also split the result of + // previously split loops if they meet the register presure criteria. + bool split_multiple_times_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_FISSION_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_fusion.cpp b/3rdparty/spirv-tools/source/opt/loop_fusion.cpp new file mode 100644 index 000000000..07d171a0a --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_fusion.cpp @@ -0,0 +1,730 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_fusion.h" + +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" + +namespace spvtools { +namespace opt { + +namespace { + +// Append all the loops nested in |loop| to |loops|. +void CollectChildren(Loop* loop, std::vector* loops) { + for (auto child : *loop) { + loops->push_back(child); + if (child->NumImmediateChildren() != 0) { + CollectChildren(child, loops); + } + } +} + +// Return the set of locations accessed by |stores| and |loads|. +std::set GetLocationsAccessed( + const std::map>& stores, + const std::map>& loads) { + std::set locations{}; + + for (const auto& kv : stores) { + locations.insert(std::get<0>(kv)); + } + + for (const auto& kv : loads) { + locations.insert(std::get<0>(kv)); + } + + return locations; +} + +// Append all dependences from |sources| to |destinations| to |dependences|. +void GetDependences(std::vector* dependences, + LoopDependenceAnalysis* analysis, + const std::vector& sources, + const std::vector& destinations, + size_t num_entries) { + for (auto source : sources) { + for (auto destination : destinations) { + DistanceVector dist(num_entries); + if (!analysis->GetDependence(source, destination, &dist)) { + dependences->push_back(dist); + } + } + } +} + +// Apped all instructions in |block| to |instructions|. +void AddInstructionsInBlock(std::vector* instructions, + BasicBlock* block) { + for (auto& inst : *block) { + instructions->push_back(&inst); + } + + instructions->push_back(block->GetLabelInst()); +} + +} // namespace + +bool LoopFusion::UsedInContinueOrConditionBlock(Instruction* phi_instruction, + Loop* loop) { + auto condition_block = loop->FindConditionBlock()->id(); + auto continue_block = loop->GetContinueBlock()->id(); + auto not_used = context_->get_def_use_mgr()->WhileEachUser( + phi_instruction, + [this, condition_block, continue_block](Instruction* instruction) { + auto block_id = context_->get_instr_block(instruction)->id(); + return block_id != condition_block && block_id != continue_block; + }); + + return !not_used; +} + +void LoopFusion::RemoveIfNotUsedContinueOrConditionBlock( + std::vector* instructions, Loop* loop) { + instructions->erase( + std::remove_if(std::begin(*instructions), std::end(*instructions), + [this, loop](Instruction* instruction) { + return !UsedInContinueOrConditionBlock(instruction, + loop); + }), + std::end(*instructions)); +} + +bool LoopFusion::AreCompatible() { + // Check that the loops are in the same function. + if (loop_0_->GetHeaderBlock()->GetParent() != + loop_1_->GetHeaderBlock()->GetParent()) { + return false; + } + + // Check that both loops have pre-header blocks. + if (!loop_0_->GetPreHeaderBlock() || !loop_1_->GetPreHeaderBlock()) { + return false; + } + + // Check there are no breaks. + if (context_->cfg()->preds(loop_0_->GetMergeBlock()->id()).size() != 1 || + context_->cfg()->preds(loop_1_->GetMergeBlock()->id()).size() != 1) { + return false; + } + + // Check there are no continues. + if (context_->cfg()->preds(loop_0_->GetContinueBlock()->id()).size() != 1 || + context_->cfg()->preds(loop_1_->GetContinueBlock()->id()).size() != 1) { + return false; + } + + // |GetInductionVariables| returns all OpPhi in the header. Check that both + // loops have exactly one that is used in the continue and condition blocks. + std::vector inductions_0{}, inductions_1{}; + loop_0_->GetInductionVariables(inductions_0); + RemoveIfNotUsedContinueOrConditionBlock(&inductions_0, loop_0_); + + if (inductions_0.size() != 1) { + return false; + } + + induction_0_ = inductions_0.front(); + + loop_1_->GetInductionVariables(inductions_1); + RemoveIfNotUsedContinueOrConditionBlock(&inductions_1, loop_1_); + + if (inductions_1.size() != 1) { + return false; + } + + induction_1_ = inductions_1.front(); + + if (!CheckInit()) { + return false; + } + + if (!CheckCondition()) { + return false; + } + + if (!CheckStep()) { + return false; + } + + // Check adjacency, |loop_0_| should come just before |loop_1_|. + // There is always at least one block between loops, even if it's empty. + // We'll check at most 2 preceeding blocks. + + auto pre_header_1 = loop_1_->GetPreHeaderBlock(); + + std::vector block_to_check{}; + block_to_check.push_back(pre_header_1); + + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + // Follow CFG for one more block. + auto preds = context_->cfg()->preds(pre_header_1->id()); + if (preds.size() == 1) { + auto block = &*containing_function_->FindBlock(preds.front()); + if (block == loop_0_->GetMergeBlock()) { + block_to_check.push_back(block); + } else { + return false; + } + } else { + return false; + } + } + + // Check that the separating blocks are either empty or only contains a store + // to a local variable that is never read (left behind by + // '--eliminate-local-multi-store'). Also allow OpPhi, since the loop could be + // in LCSSA form. + for (auto block : block_to_check) { + for (auto& inst : *block) { + if (inst.opcode() == SpvOpStore) { + // Get the definition of the target to check it's function scope so + // there are no observable side effects. + auto variable = + context_->get_def_use_mgr()->GetDef(inst.GetSingleWordInOperand(0)); + + if (variable->opcode() != SpvOpVariable || + variable->GetSingleWordInOperand(0) != SpvStorageClassFunction) { + return false; + } + + // Check the target is never loaded. + auto is_used = false; + context_->get_def_use_mgr()->ForEachUse( + inst.GetSingleWordInOperand(0), + [&is_used](Instruction* use_inst, uint32_t) { + if (use_inst->opcode() == SpvOpLoad) { + is_used = true; + } + }); + + if (is_used) { + return false; + } + } else if (inst.opcode() == SpvOpPhi) { + if (inst.NumInOperands() != 2) { + return false; + } + } else if (inst.opcode() != SpvOpBranch) { + return false; + } + } + } + + return true; +} // namespace opt + +bool LoopFusion::ContainsBarriersOrFunctionCalls(Loop* loop) { + for (const auto& block : loop->GetBlocks()) { + for (const auto& inst : *containing_function_->FindBlock(block)) { + auto opcode = inst.opcode(); + if (opcode == SpvOpFunctionCall || opcode == SpvOpControlBarrier || + opcode == SpvOpMemoryBarrier || opcode == SpvOpTypeNamedBarrier || + opcode == SpvOpNamedBarrierInitialize || + opcode == SpvOpMemoryNamedBarrier) { + return true; + } + } + } + + return false; +} + +bool LoopFusion::CheckInit() { + int64_t loop_0_init; + if (!loop_0_->GetInductionInitValue(induction_0_, &loop_0_init)) { + return false; + } + + int64_t loop_1_init; + if (!loop_1_->GetInductionInitValue(induction_1_, &loop_1_init)) { + return false; + } + + if (loop_0_init != loop_1_init) { + return false; + } + + return true; +} + +bool LoopFusion::CheckCondition() { + auto condition_0 = loop_0_->GetConditionInst(); + auto condition_1 = loop_1_->GetConditionInst(); + + if (!loop_0_->IsSupportedCondition(condition_0->opcode()) || + !loop_1_->IsSupportedCondition(condition_1->opcode())) { + return false; + } + + if (condition_0->opcode() != condition_1->opcode()) { + return false; + } + + for (uint32_t i = 0; i < condition_0->NumInOperandWords(); ++i) { + auto arg_0 = context_->get_def_use_mgr()->GetDef( + condition_0->GetSingleWordInOperand(i)); + auto arg_1 = context_->get_def_use_mgr()->GetDef( + condition_1->GetSingleWordInOperand(i)); + + if (arg_0 == induction_0_ && arg_1 == induction_1_) { + continue; + } + + if (arg_0 == induction_0_ && arg_1 != induction_1_) { + return false; + } + + if (arg_1 == induction_1_ && arg_0 != induction_0_) { + return false; + } + + if (arg_0 != arg_1) { + return false; + } + } + + return true; +} + +bool LoopFusion::CheckStep() { + auto scalar_analysis = context_->GetScalarEvolutionAnalysis(); + SENode* induction_node_0 = scalar_analysis->SimplifyExpression( + scalar_analysis->AnalyzeInstruction(induction_0_)); + if (!induction_node_0->AsSERecurrentNode()) { + return false; + } + + SENode* induction_step_0 = + induction_node_0->AsSERecurrentNode()->GetCoefficient(); + if (!induction_step_0->AsSEConstantNode()) { + return false; + } + + SENode* induction_node_1 = scalar_analysis->SimplifyExpression( + scalar_analysis->AnalyzeInstruction(induction_1_)); + if (!induction_node_1->AsSERecurrentNode()) { + return false; + } + + SENode* induction_step_1 = + induction_node_1->AsSERecurrentNode()->GetCoefficient(); + if (!induction_step_1->AsSEConstantNode()) { + return false; + } + + if (*induction_step_0 != *induction_step_1) { + return false; + } + + return true; +} + +std::map> LoopFusion::LocationToMemOps( + const std::vector& mem_ops) { + std::map> location_map{}; + + for (auto instruction : mem_ops) { + auto access_location = context_->get_def_use_mgr()->GetDef( + instruction->GetSingleWordInOperand(0)); + + while (access_location->opcode() == SpvOpAccessChain) { + access_location = context_->get_def_use_mgr()->GetDef( + access_location->GetSingleWordInOperand(0)); + } + + location_map[access_location].push_back(instruction); + } + + return location_map; +} + +std::pair, std::vector> +LoopFusion::GetLoadsAndStoresInLoop(Loop* loop) { + std::vector loads{}; + std::vector stores{}; + + for (auto block_id : loop->GetBlocks()) { + if (block_id == loop->GetContinueBlock()->id()) { + continue; + } + + for (auto& instruction : *containing_function_->FindBlock(block_id)) { + if (instruction.opcode() == SpvOpLoad) { + loads.push_back(&instruction); + } else if (instruction.opcode() == SpvOpStore) { + stores.push_back(&instruction); + } + } + } + + return std::make_pair(loads, stores); +} + +bool LoopFusion::IsUsedInLoop(Instruction* instruction, Loop* loop) { + auto not_used = context_->get_def_use_mgr()->WhileEachUser( + instruction, [this, loop](Instruction* user) { + auto block_id = context_->get_instr_block(user)->id(); + return !loop->IsInsideLoop(block_id); + }); + + return !not_used; +} + +bool LoopFusion::IsLegal() { + assert(AreCompatible() && "Fusion can't be legal, loops are not compatible."); + + // Bail out if there are function calls as they could have side-effects that + // cause dependencies or if there are any barriers. + if (ContainsBarriersOrFunctionCalls(loop_0_) || + ContainsBarriersOrFunctionCalls(loop_1_)) { + return false; + } + + std::vector phi_instructions{}; + loop_0_->GetInductionVariables(phi_instructions); + + // Check no OpPhi in |loop_0_| is used in |loop_1_|. + for (auto phi_instruction : phi_instructions) { + if (IsUsedInLoop(phi_instruction, loop_1_)) { + return false; + } + } + + // Check no LCSSA OpPhi in merge block of |loop_0_| is used in |loop_1_|. + auto phi_used = false; + loop_0_->GetMergeBlock()->ForEachPhiInst( + [this, &phi_used](Instruction* phi_instruction) { + phi_used |= IsUsedInLoop(phi_instruction, loop_1_); + }); + + if (phi_used) { + return false; + } + + // Grab loads & stores from both loops. + auto loads_stores_0 = GetLoadsAndStoresInLoop(loop_0_); + auto loads_stores_1 = GetLoadsAndStoresInLoop(loop_1_); + + // Build memory location to operation maps. + auto load_locs_0 = LocationToMemOps(std::get<0>(loads_stores_0)); + auto store_locs_0 = LocationToMemOps(std::get<1>(loads_stores_0)); + + auto load_locs_1 = LocationToMemOps(std::get<0>(loads_stores_1)); + auto store_locs_1 = LocationToMemOps(std::get<1>(loads_stores_1)); + + // Get the locations accessed in both loops. + auto locations_0 = GetLocationsAccessed(store_locs_0, load_locs_0); + auto locations_1 = GetLocationsAccessed(store_locs_1, load_locs_1); + + std::vector potential_clashes{}; + + std::set_intersection(std::begin(locations_0), std::end(locations_0), + std::begin(locations_1), std::end(locations_1), + std::back_inserter(potential_clashes)); + + // If the loops don't access the same variables, the fusion is legal. + if (potential_clashes.empty()) { + return true; + } + + // Find variables that have at least one store. + std::vector potential_clashes_with_stores{}; + for (auto location : potential_clashes) { + if (store_locs_0.find(location) != std::end(store_locs_0) || + store_locs_1.find(location) != std::end(store_locs_1)) { + potential_clashes_with_stores.push_back(location); + } + } + + // If there are only loads to the same variables, the fusion is legal. + if (potential_clashes_with_stores.empty()) { + return true; + } + + // Else if loads and at least one store (across loops) to the same variable + // there is a potential dependence and we need to check the dependence + // distance. + + // Find all the loops in this loop nest for the dependency analysis. + std::vector loops{}; + + // Find the parents. + for (auto current_loop = loop_0_; current_loop != nullptr; + current_loop = current_loop->GetParent()) { + loops.push_back(current_loop); + } + + auto this_loop_position = loops.size() - 1; + std::reverse(std::begin(loops), std::end(loops)); + + // Find the children. + CollectChildren(loop_0_, &loops); + CollectChildren(loop_1_, &loops); + + // Check that any dependes created are legal. That means the fused loops do + // not have any dependencies with dependence distance greater than 0 that did + // not exist in the original loops. + + LoopDependenceAnalysis analysis(context_, loops); + + analysis.GetScalarEvolution()->AddLoopsToPretendAreTheSame( + {loop_0_, loop_1_}); + + for (auto location : potential_clashes_with_stores) { + // Analyse dependences from |loop_0_| to |loop_1_|. + std::vector dependences; + // Read-After-Write. + GetDependences(&dependences, &analysis, store_locs_0[location], + load_locs_1[location], loops.size()); + // Write-After-Read. + GetDependences(&dependences, &analysis, load_locs_0[location], + store_locs_1[location], loops.size()); + // Write-After-Write. + GetDependences(&dependences, &analysis, store_locs_0[location], + store_locs_1[location], loops.size()); + + // Check that the induction variables either don't appear in the subscripts + // or the dependence distance is negative. + for (const auto& dependence : dependences) { + const auto& entry = dependence.GetEntries()[this_loop_position]; + if ((entry.dependence_information == + DistanceEntry::DependenceInformation::DISTANCE && + entry.distance < 1) || + (entry.dependence_information == + DistanceEntry::DependenceInformation::IRRELEVANT)) { + continue; + } else { + return false; + } + } + } + + return true; +} + +void ReplacePhiParentWith(Instruction* inst, uint32_t orig_block, + uint32_t new_block) { + if (inst->GetSingleWordInOperand(1) == orig_block) { + inst->SetInOperand(1, {new_block}); + } else { + inst->SetInOperand(3, {new_block}); + } +} + +void LoopFusion::Fuse() { + assert(AreCompatible() && "Can't fuse, loops aren't compatible"); + assert(IsLegal() && "Can't fuse, illegal"); + + // Save the pointers/ids, won't be found in the middle of doing modifications. + auto header_1 = loop_1_->GetHeaderBlock()->id(); + auto condition_1 = loop_1_->FindConditionBlock()->id(); + auto continue_1 = loop_1_->GetContinueBlock()->id(); + auto continue_0 = loop_0_->GetContinueBlock()->id(); + auto condition_block_of_0 = loop_0_->FindConditionBlock(); + + // Find the blocks whose branches need updating. + auto first_block_of_1 = &*(++containing_function_->FindBlock(condition_1)); + auto last_block_of_1 = &*(--containing_function_->FindBlock(continue_1)); + auto last_block_of_0 = &*(--containing_function_->FindBlock(continue_0)); + + // Update the branch for |last_block_of_loop_0| to go to |first_block_of_1|. + last_block_of_0->ForEachSuccessorLabel( + [first_block_of_1](uint32_t* succ) { *succ = first_block_of_1->id(); }); + + // Update the branch for the |last_block_of_loop_1| to go to the continue + // block of |loop_0_|. + last_block_of_1->ForEachSuccessorLabel( + [this](uint32_t* succ) { *succ = loop_0_->GetContinueBlock()->id(); }); + + // Update merge block id in the header of |loop_0_| to the merge block of + // |loop_1_|. + loop_0_->GetHeaderBlock()->ForEachInst([this](Instruction* inst) { + if (inst->opcode() == SpvOpLoopMerge) { + inst->SetInOperand(0, {loop_1_->GetMergeBlock()->id()}); + } + }); + + // Update condition branch target in |loop_0_| to the merge block of + // |loop_1_|. + condition_block_of_0->ForEachInst([this](Instruction* inst) { + if (inst->opcode() == SpvOpBranchConditional) { + auto loop_0_merge_block_id = loop_0_->GetMergeBlock()->id(); + + if (inst->GetSingleWordInOperand(1) == loop_0_merge_block_id) { + inst->SetInOperand(1, {loop_1_->GetMergeBlock()->id()}); + } else { + inst->SetInOperand(2, {loop_1_->GetMergeBlock()->id()}); + } + } + }); + + // Move OpPhi instructions not corresponding to the induction variable from + // the header of |loop_1_| to the header of |loop_0_|. + std::vector instructions_to_move{}; + for (auto& instruction : *loop_1_->GetHeaderBlock()) { + if (instruction.opcode() == SpvOpPhi && &instruction != induction_1_) { + instructions_to_move.push_back(&instruction); + } + } + + for (auto& it : instructions_to_move) { + it->RemoveFromList(); + it->InsertBefore(induction_0_); + } + + // Update the OpPhi parents to the correct blocks in |loop_0_|. + loop_0_->GetHeaderBlock()->ForEachPhiInst([this](Instruction* i) { + ReplacePhiParentWith(i, loop_1_->GetPreHeaderBlock()->id(), + loop_0_->GetPreHeaderBlock()->id()); + + ReplacePhiParentWith(i, loop_1_->GetContinueBlock()->id(), + loop_0_->GetContinueBlock()->id()); + }); + + // Update instruction to block mapping & DefUseManager. + for (auto& phi_instruction : instructions_to_move) { + context_->set_instr_block(phi_instruction, loop_0_->GetHeaderBlock()); + context_->get_def_use_mgr()->AnalyzeInstUse(phi_instruction); + } + + // Replace the uses of the induction variable of |loop_1_| with that the + // induction variable of |loop_0_|. + context_->ReplaceAllUsesWith(induction_1_->result_id(), + induction_0_->result_id()); + + // Replace LCSSA OpPhi in merge block of |loop_0_|. + loop_0_->GetMergeBlock()->ForEachPhiInst([this](Instruction* instruction) { + context_->ReplaceAllUsesWith(instruction->result_id(), + instruction->GetSingleWordInOperand(0)); + }); + + // Update LCSSA OpPhi in merge block of |loop_1_|. + loop_1_->GetMergeBlock()->ForEachPhiInst( + [condition_block_of_0](Instruction* instruction) { + instruction->SetInOperand(1, {condition_block_of_0->id()}); + }); + + // Move the continue block of |loop_0_| after the last block of |loop_1_|. + containing_function_->MoveBasicBlockToAfter(continue_0, last_block_of_1); + + // Gather all instructions to be killed from |loop_1_| (induction variable + // initialisation, header, condition and continue blocks). + std::vector instr_to_delete{}; + AddInstructionsInBlock(&instr_to_delete, loop_1_->GetPreHeaderBlock()); + AddInstructionsInBlock(&instr_to_delete, loop_1_->GetHeaderBlock()); + AddInstructionsInBlock(&instr_to_delete, loop_1_->FindConditionBlock()); + AddInstructionsInBlock(&instr_to_delete, loop_1_->GetContinueBlock()); + + // There was an additional empty block between the loops, kill that too. + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + AddInstructionsInBlock(&instr_to_delete, loop_0_->GetMergeBlock()); + } + + // Update the CFG, so it wouldn't need invalidating. + auto cfg = context_->cfg(); + + cfg->ForgetBlock(loop_1_->GetPreHeaderBlock()); + cfg->ForgetBlock(loop_1_->GetHeaderBlock()); + cfg->ForgetBlock(loop_1_->FindConditionBlock()); + cfg->ForgetBlock(loop_1_->GetContinueBlock()); + + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + cfg->ForgetBlock(loop_0_->GetMergeBlock()); + } + + cfg->RemoveEdge(last_block_of_0->id(), loop_0_->GetContinueBlock()->id()); + cfg->AddEdge(last_block_of_0->id(), first_block_of_1->id()); + + cfg->AddEdge(last_block_of_1->id(), loop_0_->GetContinueBlock()->id()); + + cfg->AddEdge(loop_0_->GetContinueBlock()->id(), + loop_1_->GetHeaderBlock()->id()); + + cfg->AddEdge(condition_block_of_0->id(), loop_1_->GetMergeBlock()->id()); + + // Update DefUseManager. + auto def_use_mgr = context_->get_def_use_mgr(); + + // Uses of labels that are in updated branches need analysing. + def_use_mgr->AnalyzeInstUse(last_block_of_0->terminator()); + def_use_mgr->AnalyzeInstUse(last_block_of_1->terminator()); + def_use_mgr->AnalyzeInstUse(loop_0_->GetHeaderBlock()->GetLoopMergeInst()); + def_use_mgr->AnalyzeInstUse(condition_block_of_0->terminator()); + + // Update the LoopDescriptor, so it wouldn't need invalidating. + auto ld = context_->GetLoopDescriptor(containing_function_); + + // Create a copy, so the iterator wouldn't be invalidated. + std::vector loops_to_add_remove{}; + for (auto child_loop : *loop_1_) { + loops_to_add_remove.push_back(child_loop); + } + + for (auto child_loop : loops_to_add_remove) { + loop_1_->RemoveChildLoop(child_loop); + loop_0_->AddNestedLoop(child_loop); + } + + auto loop_1_blocks = loop_1_->GetBlocks(); + + for (auto block : loop_1_blocks) { + loop_1_->RemoveBasicBlock(block); + if (block != header_1 && block != condition_1 && block != continue_1) { + loop_0_->AddBasicBlock(block); + if ((*ld)[block] == loop_1_) { + ld->SetBasicBlockToLoop(block, loop_0_); + } + } + + if ((*ld)[block] == loop_1_) { + ld->ForgetBasicBlock(block); + } + } + + loop_1_->RemoveBasicBlock(loop_1_->GetPreHeaderBlock()->id()); + ld->ForgetBasicBlock(loop_1_->GetPreHeaderBlock()->id()); + + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + loop_0_->RemoveBasicBlock(loop_0_->GetMergeBlock()->id()); + ld->ForgetBasicBlock(loop_0_->GetMergeBlock()->id()); + } + + loop_0_->SetMergeBlock(loop_1_->GetMergeBlock()); + + loop_1_->ClearBlocks(); + + ld->RemoveLoop(loop_1_); + + // Kill unnessecary instructions and remove all empty blocks. + for (auto inst : instr_to_delete) { + context_->KillInst(inst); + } + + containing_function_->RemoveEmptyBlocks(); + + // Invalidate analyses. + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisInstrToBlockMapping | + IRContext::Analysis::kAnalysisLoopAnalysis | + IRContext::Analysis::kAnalysisDefUse | IRContext::Analysis::kAnalysisCFG); +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_fusion.h b/3rdparty/spirv-tools/source/opt/loop_fusion.h new file mode 100644 index 000000000..d61d6783c --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_fusion.h @@ -0,0 +1,114 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_FUSION_H_ +#define SOURCE_OPT_LOOP_FUSION_H_ + +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" +#include "source/opt/scalar_analysis.h" + +namespace spvtools { +namespace opt { + +class LoopFusion { + public: + LoopFusion(IRContext* context, Loop* loop_0, Loop* loop_1) + : context_(context), + loop_0_(loop_0), + loop_1_(loop_1), + containing_function_(loop_0->GetHeaderBlock()->GetParent()) {} + + // Checks if the |loop_0| and |loop_1| are compatible for fusion. + // That means: + // * they both have one induction variable + // * they have the same upper and lower bounds + // - same inital value + // - same condition + // * they have the same update step + // * they are adjacent, with |loop_0| appearing before |loop_1| + // * there are no break/continue in either of them + // * they both have pre-header blocks (required for ScalarEvolutionAnalysis + // and dependence checking). + bool AreCompatible(); + + // Checks if compatible |loop_0| and |loop_1| are legal to fuse. + // * fused loops do not have any dependencies with dependence distance greater + // than 0 that did not exist in the original loops. + // * there are no function calls in the loops (could have side-effects) + bool IsLegal(); + + // Perform the actual fusion of |loop_0_| and |loop_1_|. The loops have to be + // compatible and the fusion has to be legal. + void Fuse(); + + private: + // Check that the initial values are the same. + bool CheckInit(); + + // Check that the conditions are the same. + bool CheckCondition(); + + // Check that the steps are the same. + bool CheckStep(); + + // Returns |true| if |instruction| is used in the continue or condition block + // of |loop|. + bool UsedInContinueOrConditionBlock(Instruction* instruction, Loop* loop); + + // Remove entries in |instructions| that are not used in the continue or + // condition block of |loop|. + void RemoveIfNotUsedContinueOrConditionBlock( + std::vector* instructions, Loop* loop); + + // Returns |true| if |instruction| is used in |loop|. + bool IsUsedInLoop(Instruction* instruction, Loop* loop); + + // Returns |true| if |loop| has at least one barrier or function call. + bool ContainsBarriersOrFunctionCalls(Loop* loop); + + // Get all instructions in the |loop| (except in the latch block) that have + // the opcode |opcode|. + std::pair, std::vector> + GetLoadsAndStoresInLoop(Loop* loop); + + // Given a vector of memory operations (OpLoad/OpStore), constructs a map from + // variables to the loads/stores that those variables. + std::map> LocationToMemOps( + const std::vector& mem_ops); + + IRContext* context_; + + // The original loops to be fused. + Loop* loop_0_; + Loop* loop_1_; + + // The function that contains |loop_0_| and |loop_1_|. + Function* containing_function_ = nullptr; + + // The induction variables for |loop_0_| and |loop_1_|. + Instruction* induction_0_ = nullptr; + Instruction* induction_1_ = nullptr; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_FUSION_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_fusion_pass.cpp b/3rdparty/spirv-tools/source/opt/loop_fusion_pass.cpp new file mode 100644 index 000000000..bd8444ae5 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_fusion_pass.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_fusion_pass.h" + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "source/opt/register_pressure.h" + +namespace spvtools { +namespace opt { + +Pass::Status LoopFusionPass::Process() { + bool modified = false; + Module* module = context()->module(); + + // Process each function in the module + for (Function& f : *module) { + modified |= ProcessFunction(&f); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool LoopFusionPass::ProcessFunction(Function* function) { + LoopDescriptor& ld = *context()->GetLoopDescriptor(function); + + // If a loop doesn't have a preheader needs then it needs to be created. Make + // sure to return Status::SuccessWithChange in that case. + auto modified = ld.CreatePreHeaderBlocksIfMissing(); + + // TODO(tremmelg): Could the only loop that |loop| could possibly be fused be + // picked out so don't have to check every loop + for (auto& loop_0 : ld) { + for (auto& loop_1 : ld) { + LoopFusion fusion(context(), &loop_0, &loop_1); + + if (fusion.AreCompatible() && fusion.IsLegal()) { + RegisterLiveness liveness(context(), function); + RegisterLiveness::RegionRegisterLiveness reg_pressure{}; + liveness.SimulateFusion(loop_0, loop_1, ®_pressure); + + if (reg_pressure.used_registers_ <= max_registers_per_loop_) { + fusion.Fuse(); + // Recurse, as the current iterators will have been invalidated. + ProcessFunction(function); + return true; + } + } + } + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_fusion_pass.h b/3rdparty/spirv-tools/source/opt/loop_fusion_pass.h new file mode 100644 index 000000000..3a0be6000 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/loop_fusion_pass.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_FUSION_PASS_H_ +#define SOURCE_OPT_LOOP_FUSION_PASS_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// Implements a loop fusion pass. +// This pass will look for adjacent loops that are compatible and legal to be +// fused. It will fuse all such loops as long as the register usage for the +// fused loop stays under the threshold defined by |max_registers_per_loop|. +class LoopFusionPass : public Pass { + public: + explicit LoopFusionPass(size_t max_registers_per_loop) + : Pass(), max_registers_per_loop_(max_registers_per_loop) {} + + const char* name() const override { return "loop-fusion"; } + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // succesful to indicate whether changes have been made to the modue. + Status Process() override; + + private: + // Fuse loops in |function| if compatible, legal and the fused loop won't use + // too many registers. + bool ProcessFunction(Function* function); + + // The maximum number of registers a fused loop is allowed to use. + size_t max_registers_per_loop_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_FUSION_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_peeling.cpp b/3rdparty/spirv-tools/source/opt/loop_peeling.cpp index b2d1b088c..7d27480ae 100644 --- a/3rdparty/spirv-tools/source/opt/loop_peeling.cpp +++ b/3rdparty/spirv-tools/source/opt/loop_peeling.cpp @@ -13,36 +13,40 @@ // limitations under the License. #include +#include #include #include #include #include -#include "ir_builder.h" -#include "ir_context.h" -#include "loop_descriptor.h" -#include "loop_peeling.h" -#include "loop_utils.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_peeling.h" +#include "source/opt/loop_utils.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/scalar_analysis_nodes.h" namespace spvtools { namespace opt { +size_t LoopPeelingPass::code_grow_threshold_ = 1000; void LoopPeeling::DuplicateAndConnectLoop( LoopUtils::LoopCloningResult* clone_results) { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); assert(CanPeelLoop() && "Cannot peel loop!"); - std::vector ordered_loop_blocks; - ir::BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock(); + std::vector ordered_loop_blocks; + BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock(); loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks); cloned_loop_ = loop_utils_.CloneLoop(clone_results, ordered_loop_blocks); // Add the basic block to the function. - ir::Function::iterator it = + Function::iterator it = loop_utils_.GetFunction()->FindBlock(pre_header->id()); assert(it != loop_utils_.GetFunction()->end() && "Pre-header not found in the function."); @@ -50,7 +54,7 @@ void LoopPeeling::DuplicateAndConnectLoop( clone_results->cloned_bb_.begin(), clone_results->cloned_bb_.end(), ++it); // Make the |loop_|'s preheader the |cloned_loop_| one. - ir::BasicBlock* cloned_header = cloned_loop_->GetHeaderBlock(); + BasicBlock* cloned_header = cloned_loop_->GetHeaderBlock(); pre_header->ForEachSuccessorLabel( [cloned_header](uint32_t* succ) { *succ = cloned_header->id(); }); @@ -67,7 +71,7 @@ void LoopPeeling::DuplicateAndConnectLoop( uint32_t cloned_loop_exit = 0; for (uint32_t pred_id : cfg.preds(loop_->GetMergeBlock()->id())) { if (loop_->IsInsideLoop(pred_id)) continue; - ir::BasicBlock* bb = cfg.block(pred_id); + BasicBlock* bb = cfg.block(pred_id); assert(cloned_loop_exit == 0 && "The loop has multiple exits."); cloned_loop_exit = bb->id(); bb->ForEachSuccessorLabel([this](uint32_t* succ) { @@ -112,7 +116,7 @@ void LoopPeeling::DuplicateAndConnectLoop( // } loop_->GetHeaderBlock()->ForEachPhiInst([cloned_loop_exit, def_use_mgr, clone_results, - this](ir::Instruction* phi) { + this](Instruction* phi) { for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { if (!loop_->IsInsideLoop(phi->GetSingleWordInOperand(i + 1))) { phi->SetInOperand(i, @@ -130,21 +134,28 @@ void LoopPeeling::DuplicateAndConnectLoop( cloned_loop_->SetMergeBlock(loop_->GetOrCreatePreHeaderBlock()); } -void LoopPeeling::InsertCanonicalInductionVariable() { - ir::BasicBlock::iterator insert_point = - GetClonedLoop()->GetLatchBlock()->tail(); +void LoopPeeling::InsertCanonicalInductionVariable( + LoopUtils::LoopCloningResult* clone_results) { + if (original_loop_canonical_induction_variable_) { + canonical_induction_variable_ = + context_->get_def_use_mgr()->GetDef(clone_results->value_map_.at( + original_loop_canonical_induction_variable_->result_id())); + return; + } + + BasicBlock::iterator insert_point = GetClonedLoop()->GetLatchBlock()->tail(); if (GetClonedLoop()->GetLatchBlock()->GetMergeInst()) { --insert_point; } - InstructionBuilder builder(context_, &*insert_point, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* uint_1_cst = + InstructionBuilder builder( + context_, &*insert_point, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* uint_1_cst = builder.Add32BitConstantInteger(1, int_type_->IsSigned()); // Create the increment. // Note that we do "1 + 1" here, one of the operand should the phi // value but we don't have it yet. The operand will be set latter. - ir::Instruction* iv_inc = builder.AddIAdd( + Instruction* iv_inc = builder.AddIAdd( uint_1_cst->type_id(), uint_1_cst->result_id(), uint_1_cst->result_id()); builder.SetInsertPoint(&*GetClonedLoop()->GetHeaderBlock()->begin()); @@ -168,12 +179,12 @@ void LoopPeeling::InsertCanonicalInductionVariable() { } void LoopPeeling::GetIteratorUpdateOperations( - const ir::Loop* loop, ir::Instruction* iterator, - std::unordered_set* operations) { - opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + const Loop* loop, Instruction* iterator, + std::unordered_set* operations) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); operations->insert(iterator); iterator->ForEachInId([def_use_mgr, loop, operations, this](uint32_t* id) { - ir::Instruction* insn = def_use_mgr->GetDef(*id); + Instruction* insn = def_use_mgr->GetDef(*id); if (insn->opcode() == SpvOpLabel) { return; } @@ -190,7 +201,7 @@ void LoopPeeling::GetIteratorUpdateOperations( // Gather the set of blocks for all the path from |entry| to |root|. static void GetBlocksInPath(uint32_t block, uint32_t entry, std::unordered_set* blocks_in_path, - const ir::CFG& cfg) { + const CFG& cfg) { for (uint32_t pid : cfg.preds(block)) { if (blocks_in_path->insert(pid).second) { if (pid != entry) { @@ -201,7 +212,7 @@ static void GetBlocksInPath(uint32_t block, uint32_t entry, } bool LoopPeeling::IsConditionCheckSideEffectFree() const { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); // The "do-while" form does not cause issues, the algorithm takes into account // the first iteration. @@ -215,8 +226,8 @@ bool LoopPeeling::IsConditionCheckSideEffectFree() const { &blocks_in_path, cfg); for (uint32_t bb_id : blocks_in_path) { - ir::BasicBlock* bb = cfg.block(bb_id); - if (!bb->WhileEachInst([this](ir::Instruction* insn) { + BasicBlock* bb = cfg.block(bb_id); + if (!bb->WhileEachInst([this](Instruction* insn) { if (insn->IsBranch()) return true; switch (insn->opcode()) { case SpvOpLabel: @@ -237,11 +248,10 @@ bool LoopPeeling::IsConditionCheckSideEffectFree() const { } void LoopPeeling::GetIteratingExitValues() { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); - loop_->GetHeaderBlock()->ForEachPhiInst([this](ir::Instruction* phi) { - exit_value_[phi->result_id()] = nullptr; - }); + loop_->GetHeaderBlock()->ForEachPhiInst( + [this](Instruction* phi) { exit_value_[phi->result_id()] = nullptr; }); if (!loop_->GetMergeBlock()) { return; @@ -249,7 +259,7 @@ void LoopPeeling::GetIteratingExitValues() { if (cfg.preds(loop_->GetMergeBlock()->id()).size() != 1) { return; } - opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); uint32_t condition_block_id = cfg.preds(loop_->GetMergeBlock()->id())[0]; @@ -258,8 +268,8 @@ void LoopPeeling::GetIteratingExitValues() { condition_block_id) != header_pred.end(); if (do_while_form_) { loop_->GetHeaderBlock()->ForEachPhiInst( - [condition_block_id, def_use_mgr, this](ir::Instruction* phi) { - std::unordered_set operations; + [condition_block_id, def_use_mgr, this](Instruction* phi) { + std::unordered_set operations; for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { if (condition_block_id == phi->GetSingleWordInOperand(i + 1)) { @@ -270,19 +280,19 @@ void LoopPeeling::GetIteratingExitValues() { }); } else { DominatorTree* dom_tree = - &context_->GetDominatorAnalysis(loop_utils_.GetFunction(), cfg) + &context_->GetDominatorAnalysis(loop_utils_.GetFunction()) ->GetDomTree(); - ir::BasicBlock* condition_block = cfg.block(condition_block_id); + BasicBlock* condition_block = cfg.block(condition_block_id); loop_->GetHeaderBlock()->ForEachPhiInst( - [dom_tree, condition_block, this](ir::Instruction* phi) { - std::unordered_set operations; + [dom_tree, condition_block, this](Instruction* phi) { + std::unordered_set operations; // Not the back-edge value, check if the phi instruction is the only // possible candidate. GetIteratorUpdateOperations(loop_, phi, &operations); - for (ir::Instruction* insn : operations) { + for (Instruction* insn : operations) { if (insn == phi) { continue; } @@ -297,8 +307,8 @@ void LoopPeeling::GetIteratingExitValues() { } void LoopPeeling::FixExitCondition( - const std::function& condition_builder) { - ir::CFG& cfg = *context_->cfg(); + const std::function& condition_builder) { + CFG& cfg = *context_->cfg(); uint32_t condition_block_id = 0; for (uint32_t id : cfg.preds(GetClonedLoop()->GetMergeBlock()->id())) { @@ -309,10 +319,10 @@ void LoopPeeling::FixExitCondition( } assert(condition_block_id != 0 && "2nd loop in improperly connected"); - ir::BasicBlock* condition_block = cfg.block(condition_block_id); - ir::Instruction* exit_condition = condition_block->terminator(); + BasicBlock* condition_block = cfg.block(condition_block_id); + Instruction* exit_condition = condition_block->terminator(); assert(exit_condition->opcode() == SpvOpBranchConditional); - ir::BasicBlock::iterator insert_point = condition_block->tail(); + BasicBlock::iterator insert_point = condition_block->tail(); if (condition_block->GetMergeInst()) { --insert_point; } @@ -331,17 +341,17 @@ void LoopPeeling::FixExitCondition( context_->get_def_use_mgr()->AnalyzeInstUse(exit_condition); } -ir::BasicBlock* LoopPeeling::CreateBlockBefore(ir::BasicBlock* bb) { +BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) { analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); assert(cfg.preds(bb->id()).size() == 1 && "More than one predecessor"); - std::unique_ptr new_bb = MakeUnique( - std::unique_ptr(new ir::Instruction( + std::unique_ptr new_bb = + MakeUnique(std::unique_ptr(new Instruction( context_, SpvOpLabel, 0, context_->TakeNextId(), {}))); new_bb->SetParent(loop_utils_.GetFunction()); // Update the loop descriptor. - ir::Loop* in_loop = (*loop_utils_.GetLoopDescriptor())[bb]; + Loop* in_loop = (*loop_utils_.GetLoopDescriptor())[bb]; if (in_loop) { in_loop->AddBasicBlock(new_bb.get()); loop_utils_.GetLoopDescriptor()->SetBasicBlockToLoop(new_bb->id(), in_loop); @@ -350,7 +360,7 @@ ir::BasicBlock* LoopPeeling::CreateBlockBefore(ir::BasicBlock* bb) { context_->set_instr_block(new_bb->GetLabelInst(), new_bb.get()); def_use_mgr->AnalyzeInstDefUse(new_bb->GetLabelInst()); - ir::BasicBlock* bb_pred = cfg.block(cfg.preds(bb->id())[0]); + BasicBlock* bb_pred = cfg.block(cfg.preds(bb->id())[0]); bb_pred->tail()->ForEachInId([bb, &new_bb](uint32_t* id) { if (*id == bb->id()) { *id = new_bb->id(); @@ -361,37 +371,36 @@ ir::BasicBlock* LoopPeeling::CreateBlockBefore(ir::BasicBlock* bb) { def_use_mgr->AnalyzeInstUse(&*bb_pred->tail()); // Update the incoming branch. - bb->ForEachPhiInst([&new_bb, def_use_mgr](ir::Instruction* phi) { + bb->ForEachPhiInst([&new_bb, def_use_mgr](Instruction* phi) { phi->SetInOperand(1, {new_bb->id()}); def_use_mgr->AnalyzeInstUse(phi); }); - InstructionBuilder(context_, new_bb.get(), - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping) + InstructionBuilder( + context_, new_bb.get(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping) .AddBranch(bb->id()); - cfg.AddEdge(new_bb->id(), bb->id()); + cfg.RegisterBlock(new_bb.get()); // Add the basic block to the function. - ir::Function::iterator it = loop_utils_.GetFunction()->FindBlock(bb->id()); + Function::iterator it = loop_utils_.GetFunction()->FindBlock(bb->id()); assert(it != loop_utils_.GetFunction()->end() && "Basic block not found in the function."); - ir::BasicBlock* ret = new_bb.get(); + BasicBlock* ret = new_bb.get(); loop_utils_.GetFunction()->AddBasicBlock(std::move(new_bb), it); return ret; } -ir::BasicBlock* LoopPeeling::ProtectLoop(ir::Loop* loop, - ir::Instruction* condition, - ir::BasicBlock* if_merge) { - ir::BasicBlock* if_block = loop->GetOrCreatePreHeaderBlock(); +BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition, + BasicBlock* if_merge) { + BasicBlock* if_block = loop->GetOrCreatePreHeaderBlock(); // Will no longer be a pre-header because of the if. loop->SetPreHeaderBlock(nullptr); // Kill the branch to the header. context_->KillInst(&*if_block->tail()); - InstructionBuilder builder(context_, if_block, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + InstructionBuilder builder( + context_, if_block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); builder.AddConditionalBranch(condition->result_id(), loop->GetHeaderBlock()->id(), if_merge->id(), if_merge->id()); @@ -407,28 +416,27 @@ void LoopPeeling::PeelBefore(uint32_t peel_factor) { DuplicateAndConnectLoop(&clone_results); // Add a canonical induction variable "canonical_induction_variable_". - InsertCanonicalInductionVariable(); + InsertCanonicalInductionVariable(&clone_results); - InstructionBuilder builder(context_, - &*cloned_loop_->GetPreHeaderBlock()->tail(), - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* factor = + InstructionBuilder builder( + context_, &*cloned_loop_->GetPreHeaderBlock()->tail(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* factor = builder.Add32BitConstantInteger(peel_factor, int_type_->IsSigned()); - ir::Instruction* has_remaining_iteration = builder.AddLessThan( + Instruction* has_remaining_iteration = builder.AddLessThan( factor->result_id(), loop_iteration_count_->result_id()); - ir::Instruction* max_iteration = builder.AddSelect( + Instruction* max_iteration = builder.AddSelect( factor->type_id(), has_remaining_iteration->result_id(), factor->result_id(), loop_iteration_count_->result_id()); // Change the exit condition of the cloned loop to be (exit when become // false): // "canonical_induction_variable_" < min("factor", "loop_iteration_count_") - FixExitCondition([max_iteration, this](ir::Instruction* insert_before_point) { + FixExitCondition([max_iteration, this](Instruction* insert_before_point) { return InstructionBuilder(context_, insert_before_point, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping) + IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping) .AddLessThan(canonical_induction_variable_->result_id(), max_iteration->result_id()) ->result_id(); @@ -436,15 +444,15 @@ void LoopPeeling::PeelBefore(uint32_t peel_factor) { // "Protect" the second loop: the second loop can only be executed if // |has_remaining_iteration| is true (i.e. factor < loop_iteration_count_). - ir::BasicBlock* if_merge_block = loop_->GetMergeBlock(); + BasicBlock* if_merge_block = loop_->GetMergeBlock(); loop_->SetMergeBlock(CreateBlockBefore(loop_->GetMergeBlock())); // Prevent the second loop from being executed if we already executed all the // required iterations. - ir::BasicBlock* if_block = + BasicBlock* if_block = ProtectLoop(loop_, has_remaining_iteration, if_merge_block); // Patch the phi of the merge block. if_merge_block->ForEachPhiInst( - [&clone_results, if_block, this](ir::Instruction* phi) { + [&clone_results, if_block, this](Instruction* phi) { // if_merge_block had previously only 1 predecessor. uint32_t incoming_value = phi->GetSingleWordInOperand(0); auto def_in_loop = clone_results.value_map_.find(incoming_value); @@ -458,9 +466,8 @@ void LoopPeeling::PeelBefore(uint32_t peel_factor) { }); context_->InvalidateAnalysesExceptFor( - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisLoopAnalysis | ir::IRContext::kAnalysisCFG); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG); } void LoopPeeling::PeelAfter(uint32_t peel_factor) { @@ -471,26 +478,24 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) { DuplicateAndConnectLoop(&clone_results); // Add a canonical induction variable "canonical_induction_variable_". - InsertCanonicalInductionVariable(); + InsertCanonicalInductionVariable(&clone_results); - InstructionBuilder builder(context_, - &*cloned_loop_->GetPreHeaderBlock()->tail(), - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* factor = + InstructionBuilder builder( + context_, &*cloned_loop_->GetPreHeaderBlock()->tail(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* factor = builder.Add32BitConstantInteger(peel_factor, int_type_->IsSigned()); - ir::Instruction* has_remaining_iteration = builder.AddLessThan( + Instruction* has_remaining_iteration = builder.AddLessThan( factor->result_id(), loop_iteration_count_->result_id()); // Change the exit condition of the cloned loop to be (exit when become // false): // "canonical_induction_variable_" + "factor" < "loop_iteration_count_" - FixExitCondition([factor, this](ir::Instruction* insert_before_point) { + FixExitCondition([factor, this](Instruction* insert_before_point) { InstructionBuilder cond_builder( context_, insert_before_point, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); // Build the following check: canonical_induction_variable_ + factor < // iteration_count return cond_builder @@ -512,9 +517,8 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) { // Use the second loop preheader as if merge block. // Prevent the first loop if only the peeled loop needs it. - ir::BasicBlock* if_block = - ProtectLoop(cloned_loop_, has_remaining_iteration, - GetOriginalLoop()->GetPreHeaderBlock()); + BasicBlock* if_block = ProtectLoop(cloned_loop_, has_remaining_iteration, + GetOriginalLoop()->GetPreHeaderBlock()); // Patch the phi of the header block. // We added an if to enclose the first loop and because the phi node are @@ -523,25 +527,25 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) { // We had to the preheader (our if merge block) the required phi instruction // and patch the header phi. GetOriginalLoop()->GetHeaderBlock()->ForEachPhiInst( - [&clone_results, if_block, this](ir::Instruction* phi) { - opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + [&clone_results, if_block, this](Instruction* phi) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); - auto find_value_idx = [](ir::Instruction* phi_inst, ir::Loop* loop) { + auto find_value_idx = [](Instruction* phi_inst, Loop* loop) { uint32_t preheader_value_idx = !loop->IsInsideLoop(phi_inst->GetSingleWordInOperand(1)) ? 0 : 2; return preheader_value_idx; }; - ir::Instruction* cloned_phi = + Instruction* cloned_phi = def_use_mgr->GetDef(clone_results.value_map_.at(phi->result_id())); uint32_t cloned_preheader_value = cloned_phi->GetSingleWordInOperand( find_value_idx(cloned_phi, GetClonedLoop())); - ir::Instruction* new_phi = + Instruction* new_phi = InstructionBuilder(context_, &*GetOriginalLoop()->GetPreHeaderBlock()->tail(), - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping) + IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping) .AddPhi(phi->type_id(), {phi->GetSingleWordInOperand( find_value_idx(phi, GetOriginalLoop())), @@ -554,9 +558,525 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) { }); context_->InvalidateAnalysesExceptFor( - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisLoopAnalysis | ir::IRContext::kAnalysisCFG); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG); +} + +Pass::Status LoopPeelingPass::Process() { + bool modified = false; + Module* module = context()->module(); + + // Process each function in the module + for (Function& f : *module) { + modified |= ProcessFunction(&f); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool LoopPeelingPass::ProcessFunction(Function* f) { + bool modified = false; + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); + + std::vector to_process_loop; + to_process_loop.reserve(loop_descriptor.NumLoops()); + for (Loop& l : loop_descriptor) { + to_process_loop.push_back(&l); + } + + ScalarEvolutionAnalysis scev_analysis(context()); + + for (Loop* loop : to_process_loop) { + CodeMetrics loop_size; + loop_size.Analyze(*loop); + + auto try_peel = [&loop_size, &modified, this](Loop* loop_to_peel) -> Loop* { + if (!loop_to_peel->IsLCSSA()) { + LoopUtils(context(), loop_to_peel).MakeLoopClosedSSA(); + } + + bool peeled_loop; + Loop* still_peelable_loop; + std::tie(peeled_loop, still_peelable_loop) = + ProcessLoop(loop_to_peel, &loop_size); + + if (peeled_loop) { + modified = true; + } + + return still_peelable_loop; + }; + + Loop* still_peelable_loop = try_peel(loop); + // The pass is working out the maximum factor by which a loop can be peeled. + // If the loop can potentially be peeled again, then there is only one + // possible direction, so only one call is still needed. + if (still_peelable_loop) { + try_peel(loop); + } + } + + return modified; +} + +std::pair LoopPeelingPass::ProcessLoop(Loop* loop, + CodeMetrics* loop_size) { + ScalarEvolutionAnalysis* scev_analysis = + context()->GetScalarEvolutionAnalysis(); + // Default values for bailing out. + std::pair bail_out{false, nullptr}; + + BasicBlock* exit_block = loop->FindConditionBlock(); + if (!exit_block) { + return bail_out; + } + + Instruction* exiting_iv = loop->FindConditionVariable(exit_block); + if (!exiting_iv) { + return bail_out; + } + size_t iterations = 0; + if (!loop->FindNumberOfIterations(exiting_iv, &*exit_block->tail(), + &iterations)) { + return bail_out; + } + if (!iterations) { + return bail_out; + } + + Instruction* canonical_induction_variable = nullptr; + + loop->GetHeaderBlock()->WhileEachPhiInst([&canonical_induction_variable, + scev_analysis, + this](Instruction* insn) { + if (const SERecurrentNode* iv = + scev_analysis->AnalyzeInstruction(insn)->AsSERecurrentNode()) { + const SEConstantNode* offset = iv->GetOffset()->AsSEConstantNode(); + const SEConstantNode* coeff = iv->GetCoefficient()->AsSEConstantNode(); + if (offset && coeff && offset->FoldToSingleValue() == 0 && + coeff->FoldToSingleValue() == 1) { + if (context()->get_type_mgr()->GetType(insn->type_id())->AsInteger()) { + canonical_induction_variable = insn; + return false; + } + } + } + return true; + }); + + bool is_signed = canonical_induction_variable + ? context() + ->get_type_mgr() + ->GetType(canonical_induction_variable->type_id()) + ->AsInteger() + ->IsSigned() + : false; + + LoopPeeling peeler( + loop, + InstructionBuilder( + context(), loop->GetHeaderBlock(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping) + .Add32BitConstantInteger(static_cast(iterations), + is_signed), + canonical_induction_variable); + + if (!peeler.CanPeelLoop()) { + return bail_out; + } + + // For each basic block in the loop, check if it can be peeled. If it + // can, get the direction (before/after) and by which factor. + LoopPeelingInfo peel_info(loop, iterations, scev_analysis); + + uint32_t peel_before_factor = 0; + uint32_t peel_after_factor = 0; + + for (uint32_t block : loop->GetBlocks()) { + if (block == exit_block->id()) { + continue; + } + BasicBlock* bb = cfg()->block(block); + PeelDirection direction; + uint32_t factor; + std::tie(direction, factor) = peel_info.GetPeelingInfo(bb); + + if (direction == PeelDirection::kNone) { + continue; + } + if (direction == PeelDirection::kBefore) { + peel_before_factor = std::max(peel_before_factor, factor); + } else { + assert(direction == PeelDirection::kAfter); + peel_after_factor = std::max(peel_after_factor, factor); + } + } + PeelDirection direction = PeelDirection::kNone; + uint32_t factor = 0; + + // Find which direction we should peel. + if (peel_before_factor) { + factor = peel_before_factor; + direction = PeelDirection::kBefore; + } + if (peel_after_factor) { + if (peel_before_factor < peel_after_factor) { + // Favor a peel after here and give the peel before another shot later. + factor = peel_after_factor; + direction = PeelDirection::kAfter; + } + } + + // Do the peel if we can. + if (direction == PeelDirection::kNone) return bail_out; + + // This does not take into account branch elimination opportunities and + // the unrolling. It assumes the peeled loop will be unrolled as well. + if (factor * loop_size->roi_size_ > code_grow_threshold_) { + return bail_out; + } + loop_size->roi_size_ *= factor; + + // Find if a loop should be peeled again. + Loop* extra_opportunity = nullptr; + + if (direction == PeelDirection::kBefore) { + peeler.PeelBefore(factor); + if (stats_) { + stats_->peeled_loops_.emplace_back(loop, PeelDirection::kBefore, factor); + } + if (peel_after_factor) { + // We could have peeled after, give it another try. + extra_opportunity = peeler.GetOriginalLoop(); + } + } else { + peeler.PeelAfter(factor); + if (stats_) { + stats_->peeled_loops_.emplace_back(loop, PeelDirection::kAfter, factor); + } + if (peel_before_factor) { + // We could have peeled before, give it another try. + extra_opportunity = peeler.GetClonedLoop(); + } + } + + return {true, extra_opportunity}; +} + +uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstLoopInvariantOperand( + Instruction* condition) const { + for (uint32_t i = 0; i < condition->NumInOperands(); i++) { + BasicBlock* bb = + context_->get_instr_block(condition->GetSingleWordInOperand(i)); + if (bb && loop_->IsInsideLoop(bb)) { + return condition->GetSingleWordInOperand(i); + } + } + + return 0; +} + +uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstNonLoopInvariantOperand( + Instruction* condition) const { + for (uint32_t i = 0; i < condition->NumInOperands(); i++) { + BasicBlock* bb = + context_->get_instr_block(condition->GetSingleWordInOperand(i)); + if (!bb || !loop_->IsInsideLoop(bb)) { + return condition->GetSingleWordInOperand(i); + } + } + + return 0; +} + +static bool IsHandledCondition(SpvOp opcode) { + switch (opcode) { + case SpvOpIEqual: + case SpvOpINotEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + return true; + default: + return false; + } +} + +LoopPeelingPass::LoopPeelingInfo::Direction +LoopPeelingPass::LoopPeelingInfo::GetPeelingInfo(BasicBlock* bb) const { + if (bb->terminator()->opcode() != SpvOpBranchConditional) { + return GetNoneDirection(); + } + + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + Instruction* condition = + def_use_mgr->GetDef(bb->terminator()->GetSingleWordInOperand(0)); + + if (!IsHandledCondition(condition->opcode())) { + return GetNoneDirection(); + } + + if (!GetFirstLoopInvariantOperand(condition)) { + // No loop invariant, it cannot be peeled by this pass. + return GetNoneDirection(); + } + if (!GetFirstNonLoopInvariantOperand(condition)) { + // Seems to be a job for the unswitch pass. + return GetNoneDirection(); + } + + // Left hand-side. + SExpression lhs = scev_analysis_->AnalyzeInstruction( + def_use_mgr->GetDef(condition->GetSingleWordInOperand(0))); + if (lhs->GetType() == SENode::CanNotCompute) { + // Can't make any conclusion. + return GetNoneDirection(); + } + + // Right hand-side. + SExpression rhs = scev_analysis_->AnalyzeInstruction( + def_use_mgr->GetDef(condition->GetSingleWordInOperand(1))); + if (rhs->GetType() == SENode::CanNotCompute) { + // Can't make any conclusion. + return GetNoneDirection(); + } + + // Only take into account recurrent expression over the current loop. + bool is_lhs_rec = !scev_analysis_->IsLoopInvariant(loop_, lhs); + bool is_rhs_rec = !scev_analysis_->IsLoopInvariant(loop_, rhs); + + if ((is_lhs_rec && is_rhs_rec) || (!is_lhs_rec && !is_rhs_rec)) { + return GetNoneDirection(); + } + + if (is_lhs_rec) { + if (!lhs->AsSERecurrentNode() || + lhs->AsSERecurrentNode()->GetLoop() != loop_) { + return GetNoneDirection(); + } + } + if (is_rhs_rec) { + if (!rhs->AsSERecurrentNode() || + rhs->AsSERecurrentNode()->GetLoop() != loop_) { + return GetNoneDirection(); + } + } + + // If the op code is ==, then we try a peel before or after. + // If opcode is not <, >, <= or >=, we bail out. + // + // For the remaining cases, we canonicalize the expression so that the + // constant expression is on the left hand side and the recurring expression + // is on the right hand side. If we swap hand side, then < becomes >, <= + // becomes >= etc. + // If the opcode is <=, then we add 1 to the right hand side and do the peel + // check on <. + // If the opcode is >=, then we add 1 to the left hand side and do the peel + // check on >. + + CmpOperator cmp_operator; + switch (condition->opcode()) { + default: + return GetNoneDirection(); + case SpvOpIEqual: + case SpvOpINotEqual: + return HandleEquality(lhs, rhs); + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: { + cmp_operator = CmpOperator::kGT; + break; + } + case SpvOpULessThan: + case SpvOpSLessThan: { + cmp_operator = CmpOperator::kLT; + break; + } + // We add one to transform >= into > and <= into <. + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: { + cmp_operator = CmpOperator::kGE; + break; + } + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: { + cmp_operator = CmpOperator::kLE; + break; + } + } + + // Force the left hand side to be the non recurring expression. + if (is_lhs_rec) { + std::swap(lhs, rhs); + switch (cmp_operator) { + case CmpOperator::kLT: { + cmp_operator = CmpOperator::kGT; + break; + } + case CmpOperator::kGT: { + cmp_operator = CmpOperator::kLT; + break; + } + case CmpOperator::kLE: { + cmp_operator = CmpOperator::kGE; + break; + } + case CmpOperator::kGE: { + cmp_operator = CmpOperator::kLE; + break; + } + } + } + return HandleInequality(cmp_operator, lhs, rhs->AsSERecurrentNode()); +} + +SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtFirstIteration( + SERecurrentNode* rec) const { + return rec->GetOffset(); +} + +SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtIteration( + SERecurrentNode* rec, int64_t iteration) const { + SExpression coeff = rec->GetCoefficient(); + SExpression offset = rec->GetOffset(); + + return (coeff * iteration) + offset; +} + +SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtLastIteration( + SERecurrentNode* rec) const { + return GetValueAtIteration(rec, loop_max_iterations_ - 1); +} + +bool LoopPeelingPass::LoopPeelingInfo::EvalOperator(CmpOperator cmp_op, + SExpression lhs, + SExpression rhs, + bool* result) const { + assert(scev_analysis_->IsLoopInvariant(loop_, lhs)); + assert(scev_analysis_->IsLoopInvariant(loop_, rhs)); + // We perform the test: 0 cmp_op rhs - lhs + // What is left is then to determine the sign of the expression. + switch (cmp_op) { + case CmpOperator::kLT: { + return scev_analysis_->IsAlwaysGreaterThanZero(rhs - lhs, result); + } + case CmpOperator::kGT: { + return scev_analysis_->IsAlwaysGreaterThanZero(lhs - rhs, result); + } + case CmpOperator::kLE: { + return scev_analysis_->IsAlwaysGreaterOrEqualToZero(rhs - lhs, result); + } + case CmpOperator::kGE: { + return scev_analysis_->IsAlwaysGreaterOrEqualToZero(lhs - rhs, result); + } + } + return false; +} + +LoopPeelingPass::LoopPeelingInfo::Direction +LoopPeelingPass::LoopPeelingInfo::HandleEquality(SExpression lhs, + SExpression rhs) const { + { + // Try peel before opportunity. + SExpression lhs_cst = lhs; + if (SERecurrentNode* rec_node = lhs->AsSERecurrentNode()) { + lhs_cst = rec_node->GetOffset(); + } + SExpression rhs_cst = rhs; + if (SERecurrentNode* rec_node = rhs->AsSERecurrentNode()) { + rhs_cst = rec_node->GetOffset(); + } + + if (lhs_cst == rhs_cst) { + return Direction{LoopPeelingPass::PeelDirection::kBefore, 1}; + } + } + + { + // Try peel after opportunity. + SExpression lhs_cst = lhs; + if (SERecurrentNode* rec_node = lhs->AsSERecurrentNode()) { + // rec_node(x) = a * x + b + // assign to lhs: a * (loop_max_iterations_ - 1) + b + lhs_cst = GetValueAtLastIteration(rec_node); + } + SExpression rhs_cst = rhs; + if (SERecurrentNode* rec_node = rhs->AsSERecurrentNode()) { + // rec_node(x) = a * x + b + // assign to lhs: a * (loop_max_iterations_ - 1) + b + rhs_cst = GetValueAtLastIteration(rec_node); + } + + if (lhs_cst == rhs_cst) { + return Direction{LoopPeelingPass::PeelDirection::kAfter, 1}; + } + } + + return GetNoneDirection(); +} + +LoopPeelingPass::LoopPeelingInfo::Direction +LoopPeelingPass::LoopPeelingInfo::HandleInequality(CmpOperator cmp_op, + SExpression lhs, + SERecurrentNode* rhs) const { + SExpression offset = rhs->GetOffset(); + SExpression coefficient = rhs->GetCoefficient(); + // Compute (cst - B) / A. + std::pair flip_iteration = (lhs - offset) / coefficient; + if (!flip_iteration.first->AsSEConstantNode()) { + return GetNoneDirection(); + } + // note: !!flip_iteration.second normalize to 0/1 (via bool cast). + int64_t iteration = + flip_iteration.first->AsSEConstantNode()->FoldToSingleValue() + + !!flip_iteration.second; + if (iteration <= 0 || + loop_max_iterations_ <= static_cast(iteration)) { + // Always true or false within the loop bounds. + return GetNoneDirection(); + } + // If this is a <= or >= operator and the iteration, make sure |iteration| is + // the one flipping the condition. + // If (cst - B) and A are not divisible, this equivalent to a < or > check, so + // we skip this test. + if (!flip_iteration.second && + (cmp_op == CmpOperator::kLE || cmp_op == CmpOperator::kGE)) { + bool first_iteration; + bool current_iteration; + if (!EvalOperator(cmp_op, lhs, offset, &first_iteration) || + !EvalOperator(cmp_op, lhs, GetValueAtIteration(rhs, iteration), + ¤t_iteration)) { + return GetNoneDirection(); + } + // If the condition did not flip the next will. + if (first_iteration == current_iteration) { + iteration++; + } + } + + uint32_t cast_iteration = 0; + // sanity check: can we fit |iteration| in a uint32_t ? + if (static_cast(iteration) < std::numeric_limits::max()) { + cast_iteration = static_cast(iteration); + } + + if (cast_iteration) { + // Peel before if we are closer to the start, after if closer to the end. + if (loop_max_iterations_ / 2 > cast_iteration) { + return Direction{LoopPeelingPass::PeelDirection::kBefore, cast_iteration}; + } else { + return Direction{ + LoopPeelingPass::PeelDirection::kAfter, + static_cast(loop_max_iterations_ - cast_iteration)}; + } + } + + return GetNoneDirection(); } } // namespace opt diff --git a/3rdparty/spirv-tools/source/opt/loop_peeling.h b/3rdparty/spirv-tools/source/opt/loop_peeling.h index 9b912e560..413f896f2 100644 --- a/3rdparty/spirv-tools/source/opt/loop_peeling.h +++ b/3rdparty/spirv-tools/source/opt/loop_peeling.h @@ -18,14 +18,17 @@ #include #include #include +#include #include #include #include #include -#include "opt/ir_context.h" -#include "opt/loop_descriptor.h" -#include "opt/loop_utils.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" namespace spvtools { namespace opt { @@ -61,13 +64,6 @@ namespace opt { // - The loop must not have any ambiguous iterators updates (see // "CanPeelLoop"). // The method "CanPeelLoop" checks that those constrained are met. -// -// FIXME(Victor): Allow the utility it accept an canonical induction variable -// rather than automatically create one. -// FIXME(Victor): When possible, evaluate the initial value of the second loop -// iterating values rather than using the exit value of the first loop. -// FIXME(Victor): Make the utility work-out the upper bound without having to -// provide it. This should become easy once the scalar evolution is in. class LoopPeeling { public: // LoopPeeling constructor. @@ -75,20 +71,33 @@ class LoopPeeling { // |loop_iteration_count| is the instruction holding the |loop| iteration // count, must be invariant for |loop| and must be of an int 32 type (signed // or unsigned). - LoopPeeling(ir::IRContext* context, ir::Loop* loop, - ir::Instruction* loop_iteration_count) - : context_(context), - loop_utils_(context, loop), + // |canonical_induction_variable| is an induction variable that can be used to + // count the number of iterations, must be of the same type as + // |loop_iteration_count| and start at 0 and increase by step of one at each + // iteration. The value nullptr is interpreted as no suitable variable exists + // and one will be created. + LoopPeeling(Loop* loop, Instruction* loop_iteration_count, + Instruction* canonical_induction_variable = nullptr) + : context_(loop->GetContext()), + loop_utils_(loop->GetContext(), loop), loop_(loop), loop_iteration_count_(!loop->IsInsideLoop(loop_iteration_count) ? loop_iteration_count : nullptr), int_type_(nullptr), + original_loop_canonical_induction_variable_( + canonical_induction_variable), canonical_induction_variable_(nullptr) { if (loop_iteration_count_) { int_type_ = context_->get_type_mgr() ->GetType(loop_iteration_count_->type_id()) ->AsInteger(); + if (canonical_induction_variable_) { + assert(canonical_induction_variable_->type_id() == + loop_iteration_count_->type_id() && + "loop_iteration_count and canonical_induction_variable do not " + "have the same type"); + } } GetIteratingExitValues(); } @@ -107,7 +116,7 @@ class LoopPeeling { // This restriction will not apply if a loop rotate is applied before (i.e. // becomes a do-while loop). bool CanPeelLoop() const { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); if (!loop_iteration_count_) { return false; @@ -132,7 +141,7 @@ class LoopPeeling { } return !std::any_of(exit_value_.cbegin(), exit_value_.cend(), - [](std::pair it) { + [](std::pair it) { return it.second == nullptr; }); } @@ -146,31 +155,31 @@ class LoopPeeling { void PeelAfter(uint32_t factor); // Returns the cloned loop. - ir::Loop* GetClonedLoop() { return cloned_loop_; } + Loop* GetClonedLoop() { return cloned_loop_; } // Returns the original loop. - ir::Loop* GetOriginalLoop() { return loop_; } + Loop* GetOriginalLoop() { return loop_; } private: - ir::IRContext* context_; + IRContext* context_; LoopUtils loop_utils_; // The original loop. - ir::Loop* loop_; + Loop* loop_; // The initial |loop_| upper bound. - ir::Instruction* loop_iteration_count_; + Instruction* loop_iteration_count_; // The int type to use for the canonical_induction_variable_. analysis::Integer* int_type_; // The cloned loop. - ir::Loop* cloned_loop_; + Loop* cloned_loop_; // This is set to true when the exit and back-edge branch instruction is the // same. bool do_while_form_; - + // The canonical induction variable from the original loop if it exists. + Instruction* original_loop_canonical_induction_variable_; // The canonical induction variable of the cloned loop. The induction variable // is initialized to 0 and incremented by step of 1. - ir::Instruction* canonical_induction_variable_; - + Instruction* canonical_induction_variable_; // Map between loop iterators and exit values. Loop iterators - std::unordered_map exit_value_; + std::unordered_map exit_value_; // Duplicate |loop_| and place the new loop before the cloned loop. Iterating // values from the cloned loop are then connected to the original loop as @@ -179,21 +188,22 @@ class LoopPeeling { // Insert the canonical induction variable into the first loop as a simplified // counter. - void InsertCanonicalInductionVariable(); + void InsertCanonicalInductionVariable( + LoopUtils::LoopCloningResult* clone_results); // Fixes the exit condition of the before loop. The function calls // |condition_builder| to get the condition to use in the conditional branch // of the loop exit. The loop will be exited if the condition evaluate to - // true. |condition_builder| takes an ir::Instruction* that represent the + // true. |condition_builder| takes an Instruction* that represent the // insertion point. void FixExitCondition( - const std::function& condition_builder); + const std::function& condition_builder); // Gathers all operations involved in the update of |iterator| into // |operations|. void GetIteratorUpdateOperations( - const ir::Loop* loop, ir::Instruction* iterator, - std::unordered_set* operations); + const Loop* loop, Instruction* iterator, + std::unordered_set* operations); // Gathers exiting iterator values. The function builds a map between each // iterating value in the loop (a phi instruction in the loop header) and its @@ -207,14 +217,117 @@ class LoopPeeling { // Creates a new basic block and insert it between |bb| and the predecessor of // |bb|. - ir::BasicBlock* CreateBlockBefore(ir::BasicBlock* bb); + BasicBlock* CreateBlockBefore(BasicBlock* bb); // Inserts code to only execute |loop| only if the given |condition| is true. // |if_merge| is a suitable basic block to be used by the if condition as // merge block. // The function returns the if block protecting the loop. - ir::BasicBlock* ProtectLoop(ir::Loop* loop, ir::Instruction* condition, - ir::BasicBlock* if_merge); + BasicBlock* ProtectLoop(Loop* loop, Instruction* condition, + BasicBlock* if_merge); +}; + +// Implements a loop peeling optimization. +// For each loop, the pass will try to peel it if there is conditions that +// are true for the "N" first or last iterations of the loop. +// To avoid code size explosion, too large loops will not be peeled. +class LoopPeelingPass : public Pass { + public: + // Describes the peeling direction. + enum class PeelDirection { + kNone, // Cannot peel + kBefore, // Can peel before + kAfter // Can peel last + }; + + // Holds some statistics about peeled function. + struct LoopPeelingStats { + std::vector> peeled_loops_; + }; + + LoopPeelingPass(LoopPeelingStats* stats = nullptr) : stats_(stats) {} + + // Sets the loop peeling growth threshold. If the code size increase is above + // |code_grow_threshold|, the loop will not be peeled. The code size is + // measured in terms of SPIR-V instructions. + static void SetLoopPeelingThreshold(size_t code_grow_threshold) { + code_grow_threshold_ = code_grow_threshold; + } + + // Returns the loop peeling code growth threshold. + static size_t GetLoopPeelingThreshold() { return code_grow_threshold_; } + + const char* name() const override { return "loop-peeling"; } + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // succesful to indicate whether changes have been made to the modue. + Pass::Status Process() override; + + private: + // Describes the peeling direction. + enum class CmpOperator { + kLT, // less than + kGT, // greater than + kLE, // less than or equal + kGE, // greater than or equal + }; + + class LoopPeelingInfo { + public: + using Direction = std::pair; + + LoopPeelingInfo(Loop* loop, size_t loop_max_iterations, + ScalarEvolutionAnalysis* scev_analysis) + : context_(loop->GetContext()), + loop_(loop), + scev_analysis_(scev_analysis), + loop_max_iterations_(loop_max_iterations) {} + + // Returns by how much and to which direction a loop should be peeled to + // make the conditional branch of the basic block |bb| an unconditional + // branch. If |bb|'s terminator is not a conditional branch or the condition + // is not workable then it returns PeelDirection::kNone and a 0 factor. + Direction GetPeelingInfo(BasicBlock* bb) const; + + private: + // Returns the id of the loop invariant operand of the conditional + // expression |condition|. It returns if no operand is invariant. + uint32_t GetFirstLoopInvariantOperand(Instruction* condition) const; + // Returns the id of the non loop invariant operand of the conditional + // expression |condition|. It returns if all operands are invariant. + uint32_t GetFirstNonLoopInvariantOperand(Instruction* condition) const; + + // Returns the value of |rec| at the first loop iteration. + SExpression GetValueAtFirstIteration(SERecurrentNode* rec) const; + // Returns the value of |rec| at the given |iteration|. + SExpression GetValueAtIteration(SERecurrentNode* rec, + int64_t iteration) const; + // Returns the value of |rec| at the last loop iteration. + SExpression GetValueAtLastIteration(SERecurrentNode* rec) const; + + bool EvalOperator(CmpOperator cmp_op, SExpression lhs, SExpression rhs, + bool* result) const; + + Direction HandleEquality(SExpression lhs, SExpression rhs) const; + Direction HandleInequality(CmpOperator cmp_op, SExpression lhs, + SERecurrentNode* rhs) const; + + static Direction GetNoneDirection() { + return Direction{LoopPeelingPass::PeelDirection::kNone, 0}; + } + IRContext* context_; + Loop* loop_; + ScalarEvolutionAnalysis* scev_analysis_; + size_t loop_max_iterations_; + }; + // Peel profitable loops in |f|. + bool ProcessFunction(Function* f); + // Peel |loop| if profitable. + std::pair ProcessLoop(Loop* loop, CodeMetrics* loop_size); + + static size_t code_grow_threshold_; + LoopPeelingStats* stats_; }; } // namespace opt diff --git a/3rdparty/spirv-tools/source/opt/loop_unroller.cpp b/3rdparty/spirv-tools/source/opt/loop_unroller.cpp index 39aa10810..587615edf 100644 --- a/3rdparty/spirv-tools/source/opt/loop_unroller.cpp +++ b/3rdparty/spirv-tools/source/opt/loop_unroller.cpp @@ -12,12 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/loop_unroller.h" +#include "source/opt/loop_unroller.h" + +#include #include #include +#include #include -#include "opt/ir_builder.h" -#include "opt/loop_utils.h" +#include + +#include "source/opt/ir_builder.h" +#include "source/opt/loop_utils.h" // Implements loop util unrolling functionality for fully and partially // unrolling loops. Given a factor it will duplicate the loop that many times, @@ -76,7 +81,7 @@ static const uint32_t kLoopControlIndex = 2; struct LoopUnrollState { LoopUnrollState() : previous_phi_(nullptr), - previous_continue_block_(nullptr), + previous_latch_block_(nullptr), previous_condition_block_(nullptr), new_phi(nullptr), new_continue_block(nullptr), @@ -84,11 +89,10 @@ struct LoopUnrollState { new_header_block(nullptr) {} // Initialize from the loop descriptor class. - LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* continue_block, - ir::BasicBlock* condition, - std::vector&& phis) + LoopUnrollState(Instruction* induction, BasicBlock* latch_block, + BasicBlock* condition, std::vector&& phis) : previous_phi_(induction), - previous_continue_block_(continue_block), + previous_latch_block_(latch_block), previous_condition_block_(condition), new_phi(nullptr), new_continue_block(nullptr), @@ -100,7 +104,7 @@ struct LoopUnrollState { // Swap the state so that the new nodes are now the previous nodes. void NextIterationState() { previous_phi_ = new_phi; - previous_continue_block_ = new_continue_block; + previous_latch_block_ = new_latch_block; previous_condition_block_ = new_condition_block; previous_phis_ = std::move(new_phis_); @@ -109,6 +113,7 @@ struct LoopUnrollState { new_continue_block = nullptr; new_condition_block = nullptr; new_header_block = nullptr; + new_latch_block = nullptr; // Clear new block/instruction maps. new_blocks.clear(); @@ -117,49 +122,53 @@ struct LoopUnrollState { } // The induction variable from the immediately preceding loop body. - ir::Instruction* previous_phi_; + Instruction* previous_phi_; // All the phi nodes from the previous loop iteration. - std::vector previous_phis_; + std::vector previous_phis_; - std::vector new_phis_; - // The previous continue block. The backedge will be removed from this and - // added to the new continue block. - ir::BasicBlock* previous_continue_block_; + std::vector new_phis_; + + // The previous latch block. The backedge will be removed from this and + // added to the new latch block. + BasicBlock* previous_latch_block_; // The previous condition block. This may be folded to flatten the loop. - ir::BasicBlock* previous_condition_block_; + BasicBlock* previous_condition_block_; // The new induction variable. - ir::Instruction* new_phi; + Instruction* new_phi; // The new continue block. - ir::BasicBlock* new_continue_block; + BasicBlock* new_continue_block; // The new condition block. - ir::BasicBlock* new_condition_block; + BasicBlock* new_condition_block; // The new header block. - ir::BasicBlock* new_header_block; + BasicBlock* new_header_block; + + // The new latch block. + BasicBlock* new_latch_block; // A mapping of new block ids to the original blocks which they were copied // from. - std::unordered_map new_blocks; + std::unordered_map new_blocks; // A mapping of the original instruction ids to the instruction ids to their // copies. std::unordered_map new_inst; - std::unordered_map ids_to_new_inst; + std::unordered_map ids_to_new_inst; }; // This class implements the actual unrolling. It uses a LoopUnrollState to // maintain the state of the unrolling inbetween steps. class LoopUnrollerUtilsImpl { public: - using BasicBlockListTy = std::vector>; + using BasicBlockListTy = std::vector>; - LoopUnrollerUtilsImpl(ir::IRContext* c, ir::Function* function) + LoopUnrollerUtilsImpl(IRContext* c, Function* function) : context_(c), function_(*function), loop_condition_block_(nullptr), @@ -170,7 +179,7 @@ class LoopUnrollerUtilsImpl { // Unroll the |loop| by given |factor| by copying the whole body |factor| // times. The resulting basicblock structure will remain a loop. - void PartiallyUnroll(ir::Loop*, size_t factor); + void PartiallyUnroll(Loop*, size_t factor); // If partially unrolling the |loop| would leave the loop with too many bodies // for its number of iterations then this method should be used. This method @@ -178,131 +187,131 @@ class LoopUnrollerUtilsImpl { // successor of the original's merge block. The original loop will have its // condition changed to loop over the residual part and the duplicate will be // partially unrolled. The resulting structure will be two loops. - void PartiallyUnrollResidualFactor(ir::Loop* loop, size_t factor); + void PartiallyUnrollResidualFactor(Loop* loop, size_t factor); // Fully unroll the |loop| by copying the full body by the total number of // loop iterations, folding all conditions, and removing the backedge from the // continue block to the header. - void FullyUnroll(ir::Loop* loop); + void FullyUnroll(Loop* loop); // Get the ID of the variable in the |phi| paired with |label|. - uint32_t GetPhiDefID(const ir::Instruction* phi, uint32_t label) const; + uint32_t GetPhiDefID(const Instruction* phi, uint32_t label) const; // Close the loop by removing the OpLoopMerge from the |loop| header block and // making the backedge point to the merge block. - void CloseUnrolledLoop(ir::Loop* loop); + void CloseUnrolledLoop(Loop* loop); // Remove the OpConditionalBranch instruction inside |conditional_block| used // to branch to either exit or continue the loop and replace it with an // unconditional OpBranch to block |new_target|. - void FoldConditionBlock(ir::BasicBlock* condtion_block, uint32_t new_target); + void FoldConditionBlock(BasicBlock* condtion_block, uint32_t new_target); // Add all blocks_to_add_ to function_ at the |insert_point|. - void AddBlocksToFunction(const ir::BasicBlock* insert_point); + void AddBlocksToFunction(const BasicBlock* insert_point); // Duplicates the |old_loop|, cloning each body and remaping the ids without // removing instructions or changing relative structure. Result will be stored // in |new_loop|. - void DuplicateLoop(ir::Loop* old_loop, ir::Loop* new_loop); + void DuplicateLoop(Loop* old_loop, Loop* new_loop); inline size_t GetLoopIterationCount() const { return number_of_loop_iterations_; } // Extracts the initial state information from the |loop|. - void Init(ir::Loop* loop); + void Init(Loop* loop); // Replace the uses of each induction variable outside the loop with the final // value of the induction variable before the loop exit. To reflect the proper // state of a fully unrolled loop. - void ReplaceInductionUseWithFinalValue(ir::Loop* loop); + void ReplaceInductionUseWithFinalValue(Loop* loop); // Remove all the instructions in the invalidated_instructions_ vector. void RemoveDeadInstructions(); // Replace any use of induction variables outwith the loop with the final // value of the induction variable in the unrolled loop. - void ReplaceOutsideLoopUseWithFinalValue(ir::Loop* loop); + void ReplaceOutsideLoopUseWithFinalValue(Loop* loop); // Set the LoopControl operand of the OpLoopMerge instruction to be // DontUnroll. - void MarkLoopControlAsDontUnroll(ir::Loop* loop) const; + void MarkLoopControlAsDontUnroll(Loop* loop) const; private: // Remap all the in |basic_block| to new IDs and keep the mapping of new ids // to old // ids. |loop| is used to identify special loop blocks (header, continue, // ect). - void AssignNewResultIds(ir::BasicBlock* basic_block); + void AssignNewResultIds(BasicBlock* basic_block); // Using the map built by AssignNewResultIds, for each instruction in // |basic_block| use // that map to substitute the IDs used by instructions (in the operands) with // the new ids. - void RemapOperands(ir::BasicBlock* basic_block); + void RemapOperands(BasicBlock* basic_block); // Copy the whole body of the loop, all blocks dominated by the |loop| header // and not dominated by the |loop| merge. The copied body will be linked to by // the old |loop| continue block and the new body will link to the |loop| // header via the new continue block. |eliminate_conditions| is used to decide // whether or not to fold all the condition blocks other than the last one. - void CopyBody(ir::Loop* loop, bool eliminate_conditions); + void CopyBody(Loop* loop, bool eliminate_conditions); // Copy a given |block_to_copy| in the |loop| and record the mapping of the // old/new ids. |preserve_instructions| determines whether or not the method // will modify (other than result_id) instructions which are copied. - void CopyBasicBlock(ir::Loop* loop, const ir::BasicBlock* block_to_copy, + void CopyBasicBlock(Loop* loop, const BasicBlock* block_to_copy, bool preserve_instructions); // The actual implementation of the unroll step. Unrolls |loop| by given // |factor| by copying the body by |factor| times. Also propagates the // induction variable value throughout the copies. - void Unroll(ir::Loop* loop, size_t factor); + void Unroll(Loop* loop, size_t factor); // Fills the loop_blocks_inorder_ field with the ordered list of basic blocks // as computed by the method ComputeLoopOrderedBlocks. - void ComputeLoopOrderedBlocks(ir::Loop* loop); + void ComputeLoopOrderedBlocks(Loop* loop); // Adds the blocks_to_add_ to both the |loop| and to the parent of |loop| if // the parent exists. - void AddBlocksToLoop(ir::Loop* loop) const; + void AddBlocksToLoop(Loop* loop) const; // After the partially unroll step the phi instructions in the header block // will be in an illegal format. This function makes the phis legal by making // the edge from the latch block come from the new latch block and the value // to be the actual value of the phi at that point. - void LinkLastPhisToStart(ir::Loop* loop) const; + void LinkLastPhisToStart(Loop* loop) const; // A pointer to the IRContext. Used to add/remove instructions and for usedef // chains. - ir::IRContext* context_; + IRContext* context_; // A reference the function the loop is within. - ir::Function& function_; + Function& function_; // A list of basic blocks to be added to the loop at the end of an unroll // step. BasicBlockListTy blocks_to_add_; // List of instructions which are now dead and can be removed. - std::vector invalidated_instructions_; + std::vector invalidated_instructions_; // Maintains the current state of the transform between calls to unroll. LoopUnrollState state_; // An ordered list containing the loop basic blocks. - std::vector loop_blocks_inorder_; + std::vector loop_blocks_inorder_; // The block containing the condition check which contains a conditional // branch to the merge and continue block. - ir::BasicBlock* loop_condition_block_; + BasicBlock* loop_condition_block_; // The induction variable of the loop. - ir::Instruction* loop_induction_variable_; + Instruction* loop_induction_variable_; // Phis used in the loop need to be remapped to use the actual result values // and then be remapped at the end. - std::vector loop_phi_instructions_; + std::vector loop_phi_instructions_; // The number of loop iterations that the loop would preform pre-unroll. size_t number_of_loop_iterations_; @@ -320,8 +329,8 @@ class LoopUnrollerUtilsImpl { // Retrieve the index of the OpPhi instruction |phi| which corresponds to the // incoming |block| id. -static uint32_t GetPhiIndexFromLabel(const ir::BasicBlock* block, - const ir::Instruction* phi) { +static uint32_t GetPhiIndexFromLabel(const BasicBlock* block, + const Instruction* phi) { for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { if (block->id() == phi->GetSingleWordInOperand(i)) { return i; @@ -331,7 +340,7 @@ static uint32_t GetPhiIndexFromLabel(const ir::BasicBlock* block, return 0; } -void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) { +void LoopUnrollerUtilsImpl::Init(Loop* loop) { loop_condition_block_ = loop->FindConditionBlock(); // When we reinit the second loop during PartiallyUnrollResidualFactor we need @@ -362,12 +371,11 @@ void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) { // loop it creates two loops and unrolls one and adjusts the condition on the // other. The end result being that the new loop pair iterates over the correct // number of bodies. -void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, +void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(Loop* loop, size_t factor) { - std::unique_ptr new_label{new ir::Instruction( + std::unique_ptr new_label{new Instruction( context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})}; - std::unique_ptr new_exit_bb{ - new ir::BasicBlock(std::move(new_label))}; + std::unique_ptr new_exit_bb{new BasicBlock(std::move(new_label))}; // Save the id of the block before we move it. uint32_t new_merge_id = new_exit_bb->id(); @@ -375,13 +383,13 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, // Add the block the list of blocks to add, we want this merge block to be // right at the start of the new blocks. blocks_to_add_.push_back(std::move(new_exit_bb)); - ir::BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get(); - ir::Instruction& original_conditional_branch = *loop_condition_block_->tail(); + BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get(); + Instruction& original_conditional_branch = *loop_condition_block_->tail(); // Duplicate the loop, providing access to the blocks of both loops. // This is a naked new due to the VS2013 requirement of not having unique // pointers in vectors, as it will be inserted into a vector with // loop_descriptor.AddLoop. - ir::Loop* new_loop = new ir::Loop(*loop); + Loop* new_loop = new Loop(*loop); // Clear the basic blocks of the new loop. new_loop->ClearBlocks(); @@ -409,18 +417,18 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, // Add the new merge block to the back of the list of blocks to be added. It // needs to be the last block added to maintain dominator order in the binary. blocks_to_add_.push_back( - std::unique_ptr(new_loop->GetMergeBlock())); + std::unique_ptr(new_loop->GetMergeBlock())); // Add the blocks to the function. AddBlocksToFunction(loop->GetMergeBlock()); // Reset the usedef analysis. context_->InvalidateAnalysesExceptFor( - ir::IRContext::Analysis::kAnalysisLoopAnalysis); - opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + IRContext::Analysis::kAnalysisLoopAnalysis); + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); // The loop condition. - ir::Instruction* condition_check = def_use_manager->GetDef( + Instruction* condition_check = def_use_manager->GetDef( original_conditional_branch.GetSingleWordOperand(0)); // This should have been checked by the LoopUtils::CanPerformUnroll function @@ -428,14 +436,14 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, assert(loop->IsSupportedCondition(condition_check->opcode())); // We need to account for the initial body when calculating the remainder. - int64_t remainder = ir::Loop::GetResidualConditionValue( + int64_t remainder = Loop::GetResidualConditionValue( condition_check->opcode(), loop_init_value_, loop_step_value_, number_of_loop_iterations_, factor); assert(remainder > std::numeric_limits::min() && remainder < std::numeric_limits::max()); - ir::Instruction* new_constant = nullptr; + Instruction* new_constant = nullptr; // If the remainder is negative then we add a signed constant, otherwise just // add an unsigned constant. @@ -456,14 +464,14 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, // the preheader block. For the duplicated loop we need to update the constant // to be the amount of iterations covered by the first loop and the incoming // block to be the first loops new merge block. - std::vector new_inductions; + std::vector new_inductions; new_loop->GetInductionVariables(new_inductions); - std::vector old_inductions; + std::vector old_inductions; loop->GetInductionVariables(old_inductions); for (size_t index = 0; index < new_inductions.size(); ++index) { - ir::Instruction* new_induction = new_inductions[index]; - ir::Instruction* old_induction = old_inductions[index]; + Instruction* new_induction = new_inductions[index]; + Instruction* old_induction = old_inductions[index]; // Get the index of the loop initalizer, the value coming in from the // preheader. uint32_t initalizer_index = @@ -478,7 +486,7 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, // then replace that use with the second loop induction variable. uint32_t second_loop_induction = new_induction->result_id(); auto replace_use_outside_of_loop = [loop, second_loop_induction]( - ir::Instruction* user, + Instruction* user, uint32_t operand_index) { if (!loop->IsInsideLoop(user)) { user->SetOperand(operand_index, {second_loop_induction}); @@ -490,12 +498,11 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, } context_->InvalidateAnalysesExceptFor( - ir::IRContext::Analysis::kAnalysisLoopAnalysis); + IRContext::Analysis::kAnalysisLoopAnalysis); context_->ReplaceAllUsesWith(loop->GetMergeBlock()->id(), new_merge_id); - ir::LoopDescriptor& loop_descriptor = - *context_->GetLoopDescriptor(&function_); + LoopDescriptor& loop_descriptor = *context_->GetLoopDescriptor(&function_); loop_descriptor.AddLoop(new_loop, loop->GetParent()); @@ -504,8 +511,8 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, // Mark this loop as DontUnroll as it will already be unrolled and it may not // be safe to unroll a previously partially unrolled loop. -void LoopUnrollerUtilsImpl::MarkLoopControlAsDontUnroll(ir::Loop* loop) const { - ir::Instruction* loop_merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); +void LoopUnrollerUtilsImpl::MarkLoopControlAsDontUnroll(Loop* loop) const { + Instruction* loop_merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); assert(loop_merge_inst && "Loop merge instruction could not be found after entering unroller " "(should have exited before this)"); @@ -516,13 +523,13 @@ void LoopUnrollerUtilsImpl::MarkLoopControlAsDontUnroll(ir::Loop* loop) const { // Duplicate the |loop| body |factor| - 1 number of times while keeping the loop // backedge intact. This will leave the loop with |factor| number of bodies // after accounting for the initial body. -void LoopUnrollerUtilsImpl::Unroll(ir::Loop* loop, size_t factor) { +void LoopUnrollerUtilsImpl::Unroll(Loop* loop, size_t factor) { // If we unroll a loop partially it will not be safe to unroll it further. // This is due to the current method of calculating the number of loop // iterations. MarkLoopControlAsDontUnroll(loop); - std::vector inductions; + std::vector inductions; loop->GetInductionVariables(inductions); state_ = LoopUnrollState{loop_induction_variable_, loop->GetLatchBlock(), loop_condition_block_, std::move(inductions)}; @@ -533,20 +540,20 @@ void LoopUnrollerUtilsImpl::Unroll(ir::Loop* loop, size_t factor) { void LoopUnrollerUtilsImpl::RemoveDeadInstructions() { // Remove the dead instructions. - for (ir::Instruction* inst : invalidated_instructions_) { + for (Instruction* inst : invalidated_instructions_) { context_->KillInst(inst); } } -void LoopUnrollerUtilsImpl::ReplaceInductionUseWithFinalValue(ir::Loop* loop) { +void LoopUnrollerUtilsImpl::ReplaceInductionUseWithFinalValue(Loop* loop) { context_->InvalidateAnalysesExceptFor( - ir::IRContext::Analysis::kAnalysisLoopAnalysis); - std::vector inductions; + IRContext::Analysis::kAnalysisLoopAnalysis); + std::vector inductions; loop->GetInductionVariables(inductions); for (size_t index = 0; index < inductions.size(); ++index) { uint32_t trip_step_id = GetPhiDefID(state_.previous_phis_[index], - state_.previous_continue_block_->id()); + state_.previous_latch_block_->id()); context_->ReplaceAllUsesWith(inductions[index]->result_id(), trip_step_id); invalidated_instructions_.push_back(inductions[index]); } @@ -554,7 +561,7 @@ void LoopUnrollerUtilsImpl::ReplaceInductionUseWithFinalValue(ir::Loop* loop) { // Fully unroll the loop by partially unrolling it by the number of loop // iterations minus one for the body already accounted for. -void LoopUnrollerUtilsImpl::FullyUnroll(ir::Loop* loop) { +void LoopUnrollerUtilsImpl::FullyUnroll(Loop* loop) { // We unroll the loop by number of iterations in the loop. Unroll(loop, number_of_loop_iterations_); @@ -581,18 +588,17 @@ void LoopUnrollerUtilsImpl::FullyUnroll(ir::Loop* loop) { RemoveDeadInstructions(); // Invalidate all analyses. context_->InvalidateAnalysesExceptFor( - ir::IRContext::Analysis::kAnalysisLoopAnalysis); + IRContext::Analysis::kAnalysisLoopAnalysis); } // Copy a given basic block, give it a new result_id, and store the new block // and the id mapping in the state. |preserve_instructions| is used to determine // whether or not this function should edit instructions other than the // |result_id|. -void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop, - const ir::BasicBlock* itr, +void LoopUnrollerUtilsImpl::CopyBasicBlock(Loop* loop, const BasicBlock* itr, bool preserve_instructions) { // Clone the block exactly, including the IDs. - ir::BasicBlock* basic_block = itr->Clone(context_); + BasicBlock* basic_block = itr->Clone(context_); basic_block->SetParent(itr->GetParent()); // Assign each result a new unique ID and keep a mapping of the old ids to @@ -600,10 +606,10 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop, AssignNewResultIds(basic_block); // If this is the continue block we are copying. - if (itr == loop->GetLatchBlock()) { + if (itr == loop->GetContinueBlock()) { // Make the OpLoopMerge point to this block for the continue. if (!preserve_instructions) { - ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); merge_inst->SetInOperand(1, {basic_block->id()}); } @@ -616,11 +622,14 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop, if (!preserve_instructions) { // Remove the loop merge instruction if it exists. - ir::Instruction* merge_inst = basic_block->GetLoopMergeInst(); + Instruction* merge_inst = basic_block->GetLoopMergeInst(); if (merge_inst) invalidated_instructions_.push_back(merge_inst); } } + // If this is the latch block being copied, record it in the state. + if (itr == loop->GetLatchBlock()) state_.new_latch_block = basic_block; + // If this is the condition block we are copying. if (itr == loop_condition_block_) { state_.new_condition_block = basic_block; @@ -628,38 +637,37 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop, // Add this block to the list of blocks to add to the function at the end of // the unrolling process. - blocks_to_add_.push_back(std::unique_ptr(basic_block)); + blocks_to_add_.push_back(std::unique_ptr(basic_block)); // Keep tracking the old block via a map. state_.new_blocks[itr->id()] = basic_block; } -void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop, - bool eliminate_conditions) { +void LoopUnrollerUtilsImpl::CopyBody(Loop* loop, bool eliminate_conditions) { // Copy each basic block in the loop, give them new ids, and save state // information. - for (const ir::BasicBlock* itr : loop_blocks_inorder_) { + for (const BasicBlock* itr : loop_blocks_inorder_) { CopyBasicBlock(loop, itr, false); } - // Set the previous continue block to point to the new header. - ir::Instruction& continue_branch = *state_.previous_continue_block_->tail(); - continue_branch.SetInOperand(0, {state_.new_header_block->id()}); + // Set the previous latch block to point to the new header. + Instruction& latch_branch = *state_.previous_latch_block_->tail(); + latch_branch.SetInOperand(0, {state_.new_header_block->id()}); // As the algorithm copies the original loop blocks exactly, the tail of the // latch block on iterations after the first one will be a branch to the new // header and not the actual loop header. The last continue block in the loop // should always be a backedge to the global header. - ir::Instruction& new_continue_branch = *state_.new_continue_block->tail(); - new_continue_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()}); + Instruction& new_latch_branch = *state_.new_latch_block->tail(); + new_latch_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()}); - std::vector inductions; + std::vector inductions; loop->GetInductionVariables(inductions); for (size_t index = 0; index < inductions.size(); ++index) { - ir::Instruction* master_copy = inductions[index]; + Instruction* master_copy = inductions[index]; assert(master_copy->result_id() != 0); - ir::Instruction* induction_clone = + Instruction* induction_clone = state_.ids_to_new_inst[state_.new_inst[master_copy->result_id()]]; state_.new_phis_.push_back(induction_clone); @@ -667,7 +675,7 @@ void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop, if (!state_.previous_phis_.empty()) { state_.new_inst[master_copy->result_id()] = GetPhiDefID( - state_.previous_phis_[index], state_.previous_continue_block_->id()); + state_.previous_phis_[index], state_.previous_latch_block_->id()); } else { // Do not replace the first phi block ids. state_.new_inst[master_copy->result_id()] = master_copy->result_id(); @@ -687,14 +695,14 @@ void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop, RemapOperands(pair.second); } - for (ir::Instruction* dead_phi : state_.new_phis_) + for (Instruction* dead_phi : state_.new_phis_) invalidated_instructions_.push_back(dead_phi); // Swap the state so the new is now the previous. state_.NextIterationState(); } -uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const ir::Instruction* phi, +uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const Instruction* phi, uint32_t label) const { for (uint32_t operand = 3; operand < phi->NumOperands(); operand += 2) { if (phi->GetSingleWordOperand(operand) == label) { @@ -705,10 +713,10 @@ uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const ir::Instruction* phi, return 0; } -void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block, +void LoopUnrollerUtilsImpl::FoldConditionBlock(BasicBlock* condition_block, uint32_t operand_label) { // Remove the old conditional branch to the merge and continue blocks. - ir::Instruction& old_branch = *condition_block->tail(); + Instruction& old_branch = *condition_block->tail(); uint32_t new_target = old_branch.GetSingleWordOperand(operand_label); context_->KillInst(&old_branch); @@ -717,20 +725,20 @@ void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block, builder.AddBranch(new_target); } -void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) { +void LoopUnrollerUtilsImpl::CloseUnrolledLoop(Loop* loop) { // Remove the OpLoopMerge instruction from the function. - ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); invalidated_instructions_.push_back(merge_inst); // Remove the final backedge to the header and make it point instead to the // merge block. - state_.previous_continue_block_->tail()->SetInOperand( + state_.previous_latch_block_->tail()->SetInOperand( 0, {loop->GetMergeBlock()->id()}); // Remove all induction variables as the phis will now be invalid. Replace all // uses with the constant initializer value (all uses of phis will be in // the first iteration with the subsequent phis already having been removed). - std::vector inductions; + std::vector inductions; loop->GetInductionVariables(inductions); // We can use the state instruction mechanism to replace all internal loop @@ -739,31 +747,30 @@ void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) { // use context ReplaceAllUsesWith for the uses outside the loop with the final // trip phi value. state_.new_inst.clear(); - for (ir::Instruction* induction : inductions) { + for (Instruction* induction : inductions) { uint32_t initalizer_id = GetPhiDefID(induction, loop->GetPreHeaderBlock()->id()); state_.new_inst[induction->result_id()] = initalizer_id; } - for (ir::BasicBlock* block : loop_blocks_inorder_) { + for (BasicBlock* block : loop_blocks_inorder_) { RemapOperands(block); } } // Uses the first loop to create a copy of the loop with new IDs. -void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop, - ir::Loop* new_loop) { - std::vector new_block_order; +void LoopUnrollerUtilsImpl::DuplicateLoop(Loop* old_loop, Loop* new_loop) { + std::vector new_block_order; // Copy every block in the old loop. - for (const ir::BasicBlock* itr : loop_blocks_inorder_) { + for (const BasicBlock* itr : loop_blocks_inorder_) { CopyBasicBlock(old_loop, itr, true); new_block_order.push_back(blocks_to_add_.back().get()); } // Clone the merge block, give it a new id and record it in the state. - ir::BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_); + BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_); new_merge->SetParent(old_loop->GetMergeBlock()->GetParent()); AssignNewResultIds(new_merge); state_.new_blocks[old_loop->GetMergeBlock()->id()] = new_merge; @@ -779,15 +786,16 @@ void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop, AddBlocksToLoop(new_loop); new_loop->SetHeaderBlock(state_.new_header_block); - new_loop->SetLatchBlock(state_.new_continue_block); + new_loop->SetContinueBlock(state_.new_continue_block); + new_loop->SetLatchBlock(state_.new_latch_block); new_loop->SetMergeBlock(new_merge); } // Whenever the utility copies a block it stores it in a tempory buffer, this -// function adds the buffer into the ir::Function. The blocks will be inserted +// function adds the buffer into the Function. The blocks will be inserted // after the block |insert_point|. void LoopUnrollerUtilsImpl::AddBlocksToFunction( - const ir::BasicBlock* insert_point) { + const BasicBlock* insert_point) { for (auto basic_block_iterator = function_.begin(); basic_block_iterator != function_.end(); ++basic_block_iterator) { if (basic_block_iterator->id() == insert_point->id()) { @@ -803,7 +811,7 @@ void LoopUnrollerUtilsImpl::AddBlocksToFunction( // Assign all result_ids in |basic_block| instructions to new IDs and preserve // the mapping of new ids to old ones. -void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) { +void LoopUnrollerUtilsImpl::AssignNewResultIds(BasicBlock* basic_block) { // Label instructions aren't covered by normal traversal of the // instructions. uint32_t new_label_id = context_->TakeNextId(); @@ -812,7 +820,7 @@ void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) { state_.new_inst[basic_block->GetLabelInst()->result_id()] = new_label_id; basic_block->GetLabelInst()->SetResultId(new_label_id); - for (ir::Instruction& inst : *basic_block) { + for (Instruction& inst : *basic_block) { uint32_t old_id = inst.result_id(); // Ignore stores etc. @@ -836,8 +844,8 @@ void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) { // For all instructions in |basic_block| check if the operands used are from a // copied instruction and if so swap out the operand for the copy of it. -void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) { - for (ir::Instruction& inst : *basic_block) { +void LoopUnrollerUtilsImpl::RemapOperands(BasicBlock* basic_block) { + for (Instruction& inst : *basic_block) { auto remap_operands_to_new_ids = [this](uint32_t* id) { auto itr = state_.new_inst.find(*id); @@ -852,13 +860,13 @@ void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) { // Generate the ordered list of basic blocks in the |loop| and cache it for // later use. -void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(ir::Loop* loop) { +void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(Loop* loop) { loop_blocks_inorder_.clear(); loop->ComputeLoopStructuredOrder(&loop_blocks_inorder_); } // Adds the blocks_to_add_ to both the loop and to the parent. -void LoopUnrollerUtilsImpl::AddBlocksToLoop(ir::Loop* loop) const { +void LoopUnrollerUtilsImpl::AddBlocksToLoop(Loop* loop) const { // Add the blocks to this loop. for (auto& block_itr : blocks_to_add_) { loop->AddBasicBlock(block_itr.get()); @@ -868,20 +876,20 @@ void LoopUnrollerUtilsImpl::AddBlocksToLoop(ir::Loop* loop) const { if (loop->GetParent()) AddBlocksToLoop(loop->GetParent()); } -void LoopUnrollerUtilsImpl::LinkLastPhisToStart(ir::Loop* loop) const { - std::vector inductions; +void LoopUnrollerUtilsImpl::LinkLastPhisToStart(Loop* loop) const { + std::vector inductions; loop->GetInductionVariables(inductions); for (size_t i = 0; i < inductions.size(); ++i) { - ir::Instruction* last_phi_in_block = state_.previous_phis_[i]; + Instruction* last_phi_in_block = state_.previous_phis_[i]; - uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_, - last_phi_in_block); + uint32_t phi_index = + GetPhiIndexFromLabel(state_.previous_latch_block_, last_phi_in_block); uint32_t phi_variable = last_phi_in_block->GetSingleWordInOperand(phi_index - 1); uint32_t phi_label = last_phi_in_block->GetSingleWordInOperand(phi_index); - ir::Instruction* phi = inductions[i]; + Instruction* phi = inductions[i]; phi->SetInOperand(phi_index - 1, {phi_variable}); phi->SetInOperand(phi_index, {phi_label}); } @@ -889,7 +897,7 @@ void LoopUnrollerUtilsImpl::LinkLastPhisToStart(ir::Loop* loop) const { // Duplicate the |loop| body |factor| number of times while keeping the loop // backedge intact. -void LoopUnrollerUtilsImpl::PartiallyUnroll(ir::Loop* loop, size_t factor) { +void LoopUnrollerUtilsImpl::PartiallyUnroll(Loop* loop, size_t factor) { Unroll(loop, factor); LinkLastPhisToStart(loop); AddBlocksToLoop(loop); @@ -916,20 +924,20 @@ bool LoopUtils::CanPerformUnroll() { } // Find check the loop has a condition we can find and evaluate. - const ir::BasicBlock* condition = loop_->FindConditionBlock(); + const BasicBlock* condition = loop_->FindConditionBlock(); if (!condition) return false; // Check that we can find and process the induction variable. - const ir::Instruction* induction = loop_->FindConditionVariable(condition); + const Instruction* induction = loop_->FindConditionVariable(condition); if (!induction || induction->opcode() != SpvOpPhi) return false; // Check that we can find the number of loop iterations. if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr)) return false; - // Make sure the continue block is a unconditional branch to the header + // Make sure the latch block is a unconditional branch to the header // block. - const ir::Instruction& branch = *loop_->GetLatchBlock()->ctail(); + const Instruction& branch = *loop_->GetLatchBlock()->ctail(); bool branching_assumption = branch.opcode() == SpvOpBranch && branch.GetSingleWordInOperand(0) == loop_->GetHeaderBlock()->id(); @@ -937,7 +945,7 @@ bool LoopUtils::CanPerformUnroll() { return false; } - std::vector inductions; + std::vector inductions; loop_->GetInductionVariables(inductions); // Ban breaks within the loop. @@ -949,7 +957,7 @@ bool LoopUtils::CanPerformUnroll() { // Ban continues within the loop. const std::vector& continue_block_preds = - context_->cfg()->preds(loop_->GetLatchBlock()->id()); + context_->cfg()->preds(loop_->GetContinueBlock()->id()); if (continue_block_preds.size() != 1) { return false; } @@ -958,7 +966,7 @@ bool LoopUtils::CanPerformUnroll() { // Iterate over all the blocks within the loop and check that none of them // exit the loop. for (uint32_t label_id : loop_->GetBlocks()) { - const ir::BasicBlock* block = context_->cfg()->block(label_id); + const BasicBlock* block = context_->cfg()->block(label_id); if (block->ctail()->opcode() == SpvOp::SpvOpKill || block->ctail()->opcode() == SpvOp::SpvOpReturn || block->ctail()->opcode() == SpvOp::SpvOpReturnValue) { @@ -1004,7 +1012,7 @@ bool LoopUtils::PartiallyUnroll(size_t factor) { bool LoopUtils::FullyUnroll() { if (!CanPerformUnroll()) return false; - std::vector inductions; + std::vector inductions; loop_->GetInductionVariables(inductions); LoopUnrollerUtilsImpl unroller{context_, @@ -1019,7 +1027,7 @@ bool LoopUtils::FullyUnroll() { void LoopUtils::Finalize() { // Clean up the loop descriptor to preserve the analysis. - ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&function_); + LoopDescriptor* LD = context_->GetLoopDescriptor(&function_); LD->PostModificationCleanup(); } @@ -1029,13 +1037,12 @@ void LoopUtils::Finalize() { * */ -Pass::Status LoopUnroller::Process(ir::IRContext* c) { - context_ = c; +Pass::Status LoopUnroller::Process() { bool changed = false; - for (ir::Function& f : *c->module()) { - ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&f); - for (ir::Loop& loop : *LD) { - LoopUtils loop_utils{c, &loop}; + for (Function& f : *context()->module()) { + LoopDescriptor* LD = context()->GetLoopDescriptor(&f); + for (Loop& loop : *LD) { + LoopUtils loop_utils{context(), &loop}; if (!loop.HasUnrollLoopControl() || !loop_utils.CanPerformUnroll()) { continue; } diff --git a/3rdparty/spirv-tools/source/opt/loop_unroller.h b/3rdparty/spirv-tools/source/opt/loop_unroller.h index caf0a8ed3..eb358ae24 100644 --- a/3rdparty/spirv-tools/source/opt/loop_unroller.h +++ b/3rdparty/spirv-tools/source/opt/loop_unroller.h @@ -14,7 +14,8 @@ #ifndef SOURCE_OPT_LOOP_UNROLLER_H_ #define SOURCE_OPT_LOOP_UNROLLER_H_ -#include "opt/pass.h" + +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -25,12 +26,11 @@ class LoopUnroller : public Pass { LoopUnroller(bool fully_unroll, int unroll_factor) : Pass(), fully_unroll_(fully_unroll), unroll_factor_(unroll_factor) {} - const char* name() const override { return "Loop unroller"; } + const char* name() const override { return "loop-unroll"; } - Status Process(ir::IRContext* context) override; + Status Process() override; private: - ir::IRContext* context_; bool fully_unroll_; int unroll_factor_; }; diff --git a/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.cpp b/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.cpp index 964e765b9..59a0cbcd3 100644 --- a/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "loop_unswitch_pass.h" +#include "source/opt/loop_unswitch_pass.h" #include #include @@ -23,16 +23,16 @@ #include #include -#include "basic_block.h" -#include "dominator_tree.h" -#include "fold.h" -#include "function.h" -#include "instruction.h" -#include "ir_builder.h" -#include "ir_context.h" -#include "loop_descriptor.h" +#include "source/opt/basic_block.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/fold.h" +#include "source/opt/function.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" -#include "loop_utils.h" +#include "source/opt/loop_utils.h" namespace spvtools { namespace opt { @@ -52,8 +52,8 @@ namespace { // - The loop invariant condition is not uniform. class LoopUnswitch { public: - LoopUnswitch(ir::IRContext* context, ir::Function* function, ir::Loop* loop, - ir::LoopDescriptor* loop_desc) + LoopUnswitch(IRContext* context, Function* function, Loop* loop, + LoopDescriptor* loop_desc) : function_(function), loop_(loop), loop_desc_(*loop_desc), @@ -70,10 +70,10 @@ class LoopUnswitch { if (switch_block_) return true; if (loop_->IsSafeToClone()) return false; - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); for (uint32_t bb_id : loop_->GetBlocks()) { - ir::BasicBlock* bb = cfg.block(bb_id); + BasicBlock* bb = cfg.block(bb_id); if (bb->terminator()->IsBranch() && bb->terminator()->opcode() != SpvOpBranch) { if (IsConditionLoopInvariant(bb->terminator())) { @@ -87,8 +87,8 @@ class LoopUnswitch { } // Return the iterator to the basic block |bb|. - ir::Function::iterator FindBasicBlockPosition(ir::BasicBlock* bb_to_find) { - ir::Function::iterator it = function_->FindBlock(bb_to_find->id()); + Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) { + Function::iterator it = function_->FindBlock(bb_to_find->id()); assert(it != function_->end() && "Basic Block not found"); return it; } @@ -96,11 +96,11 @@ class LoopUnswitch { // Creates a new basic block and insert it into the function |fn| at the // position |ip|. This function preserves the def/use and instr to block // managers. - ir::BasicBlock* CreateBasicBlock(ir::Function::iterator ip) { + BasicBlock* CreateBasicBlock(Function::iterator ip) { analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); - ir::BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr( - new ir::BasicBlock(std::unique_ptr(new ir::Instruction( + BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); bb->SetParent(function_); def_use_mgr->AnalyzeInstDef(bb->GetLabelInst()); @@ -116,10 +116,9 @@ class LoopUnswitch { assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block"); assert(loop_->IsLCSSA() && "This loop is not in LCSSA form"); - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); DominatorTree* dom_tree = - &context_->GetDominatorAnalysis(function_, *context_->cfg()) - ->GetDomTree(); + &context_->GetDominatorAnalysis(function_)->GetDomTree(); analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); LoopUtils loop_utils(context_, loop_); @@ -133,29 +132,28 @@ class LoopUnswitch { ////////////////////////////////////////////////////////////////////////////// // Get the merge block if it exists. - ir::BasicBlock* if_merge_block = loop_->GetMergeBlock(); + BasicBlock* if_merge_block = loop_->GetMergeBlock(); // The merge block is only created if the loop has a unique exit block. We // have this guarantee for structured loops, for compute loop it will // trivially help maintain both a structured-like form and LCSAA. - ir::BasicBlock* loop_merge_block = + BasicBlock* loop_merge_block = if_merge_block ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block)) : nullptr; if (loop_merge_block) { // Add the instruction and update managers. - opt::InstructionBuilder builder( + InstructionBuilder builder( context_, loop_merge_block, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); builder.AddBranch(if_merge_block->id()); builder.SetInsertPoint(&*loop_merge_block->begin()); cfg.RegisterBlock(loop_merge_block); def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst()); // Update CFG. if_merge_block->ForEachPhiInst( - [loop_merge_block, &builder, this](ir::Instruction* phi) { - ir::Instruction* cloned = phi->Clone(context_); - builder.AddInstruction(std::unique_ptr(cloned)); + [loop_merge_block, &builder, this](Instruction* phi) { + Instruction* cloned = phi->Clone(context_); + builder.AddInstruction(std::unique_ptr(cloned)); phi->SetInOperand(0, {cloned->result_id()}); phi->SetInOperand(1, {loop_merge_block->id()}); for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--) @@ -165,7 +163,7 @@ class LoopUnswitch { std::vector preds = cfg.preds(if_merge_block->id()); for (uint32_t pid : preds) { if (pid == loop_merge_block->id()) continue; - ir::BasicBlock* p_bb = cfg.block(pid); + BasicBlock* p_bb = cfg.block(pid); p_bb->ForEachSuccessorLabel( [if_merge_block, loop_merge_block](uint32_t* id) { if (*id == if_merge_block->id()) *id = loop_merge_block->id(); @@ -174,7 +172,7 @@ class LoopUnswitch { } cfg.RemoveNonExistingEdges(if_merge_block->id()); // Update loop descriptor. - if (ir::Loop* ploop = loop_->GetParent()) { + if (Loop* ploop = loop_->GetParent()) { ploop->AddBasicBlock(loop_merge_block); loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop); } @@ -199,20 +197,20 @@ class LoopUnswitch { // for the constant branch. //////////////////////////////////////////////////////////////////////////// - ir::BasicBlock* if_block = loop_->GetPreHeaderBlock(); + BasicBlock* if_block = loop_->GetPreHeaderBlock(); // If this preheader is the parent loop header, // we need to create a dedicated block for the if. - ir::BasicBlock* loop_pre_header = + BasicBlock* loop_pre_header = CreateBasicBlock(++FindBasicBlockPosition(if_block)); - opt::InstructionBuilder(context_, loop_pre_header, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping) + InstructionBuilder( + context_, loop_pre_header, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping) .AddBranch(loop_->GetHeaderBlock()->id()); if_block->tail()->SetInOperand(0, {loop_pre_header->id()}); // Update loop descriptor. - if (ir::Loop* ploop = loop_desc_[if_block]) { + if (Loop* ploop = loop_desc_[if_block]) { ploop->AddBasicBlock(loop_pre_header); loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop); } @@ -224,7 +222,7 @@ class LoopUnswitch { cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id()); loop_->GetHeaderBlock()->ForEachPhiInst( - [loop_pre_header, if_block](ir::Instruction* phi) { + [loop_pre_header, if_block](Instruction* phi) { phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) { if (*id == if_block->id()) { *id = loop_pre_header->id(); @@ -260,9 +258,9 @@ class LoopUnswitch { // - Specialize the loop // ///////////////////////////// - ir::Instruction* iv_condition = &*switch_block_->tail(); + Instruction* iv_condition = &*switch_block_->tail(); SpvOp iv_opcode = iv_condition->opcode(); - ir::Instruction* condition = + Instruction* condition = def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]); analysis::ConstantManager* cst_mgr = context_->get_constant_mgr(); @@ -271,10 +269,10 @@ class LoopUnswitch { // Build the list of value for which we need to clone and specialize the // loop. - std::vector> constant_branch; + std::vector> constant_branch; // Special case for the original loop - ir::Instruction* original_loop_constant_value; - ir::BasicBlock* original_loop_target; + Instruction* original_loop_constant_value; + BasicBlock* original_loop_target; if (iv_opcode == SpvOpBranchConditional) { constant_branch.emplace_back( cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})), @@ -309,13 +307,13 @@ class LoopUnswitch { } for (auto& specialisation_pair : constant_branch) { - ir::Instruction* specialisation_value = specialisation_pair.first; + Instruction* specialisation_value = specialisation_pair.first; ////////////////////////////////////////////////////////// // Step 3: Duplicate |loop_|. ////////////////////////////////////////////////////////// LoopUtils::LoopCloningResult clone_result; - ir::Loop* cloned_loop = + Loop* cloned_loop = loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_); specialisation_pair.second = cloned_loop->GetPreHeaderBlock(); @@ -327,11 +325,11 @@ class LoopUnswitch { std::unordered_set dead_blocks; std::unordered_set unreachable_merges; SimplifyLoop( - ir::make_range( - ir::UptrVectorIterator( - &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()), - ir::UptrVectorIterator( - &clone_result.cloned_bb_, clone_result.cloned_bb_.end())), + make_range( + UptrVectorIterator(&clone_result.cloned_bb_, + clone_result.cloned_bb_.begin()), + UptrVectorIterator(&clone_result.cloned_bb_, + clone_result.cloned_bb_.end())), cloned_loop, condition, specialisation_value, &dead_blocks); // We tagged dead blocks, create the loop before we invalidate any basic @@ -339,8 +337,8 @@ class LoopUnswitch { cloned_loop = CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges); CleanUpCFG( - ir::UptrVectorIterator( - &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()), + UptrVectorIterator(&clone_result.cloned_bb_, + clone_result.cloned_bb_.begin()), dead_blocks, unreachable_merges); /////////////////////////////////////////////////////////// @@ -348,10 +346,10 @@ class LoopUnswitch { /////////////////////////////////////////////////////////// for (uint32_t merge_bb_id : if_merging_blocks) { - ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id); + BasicBlock* merge = context_->cfg()->block(merge_bb_id); // We are in LCSSA so we only care about phi instructions. merge->ForEachPhiInst([is_from_original_loop, &dead_blocks, - &clone_result](ir::Instruction* phi) { + &clone_result](Instruction* phi) { uint32_t num_in_operands = phi->NumInOperands(); for (uint32_t i = 0; i < num_in_operands; i += 2) { uint32_t pred = phi->GetSingleWordInOperand(i + 1); @@ -382,11 +380,11 @@ class LoopUnswitch { { std::unordered_set dead_blocks; std::unordered_set unreachable_merges; - SimplifyLoop(ir::make_range(function_->begin(), function_->end()), loop_, + SimplifyLoop(make_range(function_->begin(), function_->end()), loop_, condition, original_loop_constant_value, &dead_blocks); for (uint32_t merge_bb_id : if_merging_blocks) { - ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id); + BasicBlock* merge = context_->cfg()->block(merge_bb_id); // LCSSA, so we only care about phi instructions. // If we the phi is reduced to a single incoming branch, do not // propagate it to preserve LCSSA. @@ -417,7 +415,7 @@ class LoopUnswitch { // Delete the old jump context_->KillInst(&*if_block->tail()); - opt::InstructionBuilder builder(context_, if_block); + InstructionBuilder builder(context_, if_block); if (iv_opcode == SpvOpBranchConditional) { assert(constant_branch.size() == 1); builder.AddConditionalBranch( @@ -425,7 +423,7 @@ class LoopUnswitch { constant_branch[0].second->id(), if_merge_block ? if_merge_block->id() : kInvalidId); } else { - std::vector, uint32_t>> targets; + std::vector> targets; for (auto& t : constant_branch) { targets.emplace_back(t.first->GetInOperand(0).words, t.second->id()); } @@ -439,7 +437,7 @@ class LoopUnswitch { ordered_loop_blocks_.clear(); context_->InvalidateAnalysesExceptFor( - ir::IRContext::Analysis::kAnalysisLoopAnalysis); + IRContext::Analysis::kAnalysisLoopAnalysis); } // Returns true if the unswitch killed the original |loop_|. @@ -447,37 +445,37 @@ class LoopUnswitch { private: using ValueMapTy = std::unordered_map; - using BlockMapTy = std::unordered_map; + using BlockMapTy = std::unordered_map; - ir::Function* function_; - ir::Loop* loop_; - ir::LoopDescriptor& loop_desc_; - ir::IRContext* context_; + Function* function_; + Loop* loop_; + LoopDescriptor& loop_desc_; + IRContext* context_; - ir::BasicBlock* switch_block_; + BasicBlock* switch_block_; // Map between instructions and if they are dynamically uniform. std::unordered_map dynamically_uniform_; // The loop basic blocks in structured order. - std::vector ordered_loop_blocks_; + std::vector ordered_loop_blocks_; // Returns the next usable id for the context. uint32_t TakeNextId() { return context_->TakeNextId(); } // Patches |bb|'s phi instruction by removing incoming value from unexisting // or tagged as dead branches. - void PatchPhis(ir::BasicBlock* bb, + void PatchPhis(BasicBlock* bb, const std::unordered_set& dead_blocks, bool preserve_phi) { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); - std::vector phi_to_kill; + std::vector phi_to_kill; const std::vector& bb_preds = cfg.preds(bb->id()); auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) { return dead_blocks.count(id) || std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end(); }; bb->ForEachPhiInst([&phi_to_kill, &is_branch_dead, preserve_phi, - this](ir::Instruction* insn) { + this](Instruction* insn) { uint32_t i = 0; while (i < insn->NumInOperands()) { uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1); @@ -498,7 +496,7 @@ class LoopUnswitch { insn->GetSingleWordInOperand(0)); } }); - for (ir::Instruction* insn : phi_to_kill) { + for (Instruction* insn : phi_to_kill) { context_->KillInst(insn); } } @@ -506,20 +504,20 @@ class LoopUnswitch { // Removes any block that is tagged as dead, if the block is in // |unreachable_merges| then all block's instructions are replaced by a // OpUnreachable. - void CleanUpCFG(ir::UptrVectorIterator bb_it, + void CleanUpCFG(UptrVectorIterator bb_it, const std::unordered_set& dead_blocks, const std::unordered_set& unreachable_merges) { - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); while (bb_it != bb_it.End()) { - ir::BasicBlock& bb = *bb_it; + BasicBlock& bb = *bb_it; if (unreachable_merges.count(bb.id())) { if (bb.begin() != bb.tail() || bb.terminator()->opcode() != SpvOpUnreachable) { // Make unreachable, but leave the label. bb.KillAllInsts(false); - opt::InstructionBuilder(context_, &bb).AddUnreachable(); + InstructionBuilder(context_, &bb).AddUnreachable(); cfg.RemoveNonExistingEdges(bb.id()); } ++bb_it; @@ -537,7 +535,7 @@ class LoopUnswitch { // Return true if |c_inst| is a Boolean constant and set |cond_val| with the // value that |c_inst| - bool GetConstCondition(const ir::Instruction* c_inst, bool* cond_val) { + bool GetConstCondition(const Instruction* c_inst, bool* cond_val) { bool cond_is_const; switch (c_inst->opcode()) { case SpvOpConstantFalse: { @@ -564,35 +562,36 @@ class LoopUnswitch { // - |loop| must be in the LCSSA form; // - |cst_value| must be constant or null (to represent the default target // of an OpSwitch). - void SimplifyLoop( - ir::IteratorRange> block_range, - ir::Loop* loop, ir::Instruction* to_version_insn, - ir::Instruction* cst_value, std::unordered_set* dead_blocks) { - ir::CFG& cfg = *context_->cfg(); + void SimplifyLoop(IteratorRange> block_range, + Loop* loop, Instruction* to_version_insn, + Instruction* cst_value, + std::unordered_set* dead_blocks) { + CFG& cfg = *context_->cfg(); analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); std::function ignore_node; ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); }; - std::vector> use_list; - def_use_mgr->ForEachUse( - to_version_insn, [&use_list, &ignore_node, this]( - ir::Instruction* inst, uint32_t operand_index) { - ir::BasicBlock* bb = context_->get_instr_block(inst); + std::vector> use_list; + def_use_mgr->ForEachUse(to_version_insn, + [&use_list, &ignore_node, this]( + Instruction* inst, uint32_t operand_index) { + BasicBlock* bb = context_->get_instr_block(inst); - if (!bb || ignore_node(bb->id())) { - // Out of the loop, the specialization does not apply any more. - return; - } - use_list.emplace_back(inst, operand_index); - }); + if (!bb || ignore_node(bb->id())) { + // Out of the loop, the specialization does not + // apply any more. + return; + } + use_list.emplace_back(inst, operand_index); + }); // First pass: inject the specialized value into the loop (and only the // loop). for (auto use : use_list) { - ir::Instruction* inst = use.first; + Instruction* inst = use.first; uint32_t operand_index = use.second; - ir::BasicBlock* bb = context_->get_instr_block(inst); + BasicBlock* bb = context_->get_instr_block(inst); // If it is not a branch, simply inject the value. if (!inst->IsBranch()) { @@ -628,9 +627,9 @@ class LoopUnswitch { live_target = inst->GetSingleWordInOperand(1); if (cst_value) { if (!cst_value->IsConstant()) break; - const ir::Operand& cst = cst_value->GetInOperand(0); + const Operand& cst = cst_value->GetInOperand(0); for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) { - const ir::Operand& literal = inst->GetInOperand(i); + const Operand& literal = inst->GetInOperand(i); if (literal == cst) { live_target = inst->GetSingleWordInOperand(i + 1); break; @@ -649,13 +648,11 @@ class LoopUnswitch { } if (live_target != 0) { // Check for the presence of the merge block. - if (ir::Instruction* merge = bb->GetMergeInst()) - context_->KillInst(merge); + if (Instruction* merge = bb->GetMergeInst()) context_->KillInst(merge); context_->KillInst(&*bb->tail()); - opt::InstructionBuilder builder( - context_, bb, - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); + InstructionBuilder builder(context_, bb, + IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping); builder.AddBranch(live_target); } } @@ -663,7 +660,7 @@ class LoopUnswitch { // Go through the loop basic block and tag all blocks that are obviously // dead. std::unordered_set visited; - for (ir::BasicBlock& bb : block_range) { + for (BasicBlock& bb : block_range) { if (ignore_node(bb.id())) continue; visited.insert(bb.id()); @@ -678,12 +675,12 @@ class LoopUnswitch { } if (!has_live_pred) { dead_blocks->insert(bb.id()); - const ir::BasicBlock& cbb = bb; + const BasicBlock& cbb = bb; // Patch the phis for any back-edge. cbb.ForEachSuccessorLabel( [dead_blocks, &visited, &cfg, this](uint32_t id) { if (!visited.count(id) || dead_blocks->count(id)) return; - ir::BasicBlock* succ = cfg.block(id); + BasicBlock* succ = cfg.block(id); PatchPhis(succ, *dead_blocks, false); }); continue; @@ -695,7 +692,7 @@ class LoopUnswitch { // Returns true if the header is not reachable or tagged as dead or if we // never loop back. - bool IsLoopDead(ir::BasicBlock* header, ir::BasicBlock* latch, + bool IsLoopDead(BasicBlock* header, BasicBlock* latch, const std::unordered_set& dead_blocks) { if (!header || dead_blocks.count(header->id())) return true; if (!latch || dead_blocks.count(latch->id())) return true; @@ -715,15 +712,14 @@ class LoopUnswitch { // |unreachable_merges|. // The function returns the pointer to |loop| or nullptr if the loop was // killed. - ir::Loop* CleanLoopNest(ir::Loop* loop, - const std::unordered_set& dead_blocks, - std::unordered_set* unreachable_merges) { + Loop* CleanLoopNest(Loop* loop, + const std::unordered_set& dead_blocks, + std::unordered_set* unreachable_merges) { // This represent the pair of dead loop and nearest alive parent (nullptr if // no parent). - std::unordered_map dead_loops; - auto get_parent = [&dead_loops](ir::Loop* l) -> ir::Loop* { - std::unordered_map::iterator it = - dead_loops.find(l); + std::unordered_map dead_loops; + auto get_parent = [&dead_loops](Loop* l) -> Loop* { + std::unordered_map::iterator it = dead_loops.find(l); if (it != dead_loops.end()) return it->second; return nullptr; }; @@ -731,20 +727,21 @@ class LoopUnswitch { bool is_main_loop_dead = IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks); if (is_main_loop_dead) { - if (ir::Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) { + if (Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) { context_->KillInst(merge); } dead_loops[loop] = loop->GetParent(); - } else + } else { dead_loops[loop] = loop; + } + // For each loop, check if we killed it. If we did, find a suitable parent // for its children. - for (ir::Loop& sub_loop : - ir::make_range(++opt::TreeDFIterator(loop), - opt::TreeDFIterator())) { + for (Loop& sub_loop : + make_range(++TreeDFIterator(loop), TreeDFIterator())) { if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(), dead_blocks)) { - if (ir::Instruction* merge = + if (Instruction* merge = sub_loop.GetHeaderBlock()->GetLoopMergeInst()) { context_->KillInst(merge); } @@ -764,7 +761,7 @@ class LoopUnswitch { // Remove dead blocks from live loops. for (uint32_t bb_id : dead_blocks) { - ir::Loop* l = loop_desc_[bb_id]; + Loop* l = loop_desc_[bb_id]; if (l) { l->RemoveBasicBlock(bb_id); loop_desc_.ForgetBasicBlock(bb_id); @@ -773,8 +770,8 @@ class LoopUnswitch { std::for_each( dead_loops.begin(), dead_loops.end(), - [&loop, this]( - std::unordered_map::iterator::reference it) { + [&loop, + this](std::unordered_map::iterator::reference it) { if (it.first == loop) loop = nullptr; loop_desc_.RemoveLoop(it.first); }); @@ -784,7 +781,7 @@ class LoopUnswitch { // Returns true if |var| is dynamically uniform. // Note: this is currently approximated as uniform. - bool IsDynamicallyUniform(ir::Instruction* var, const ir::BasicBlock* entry, + bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry, const DominatorTree& post_dom_tree) { assert(post_dom_tree.IsPostDominator()); analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); @@ -799,7 +796,7 @@ class LoopUnswitch { is_uniform = false; dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform, - [&is_uniform](const ir::Instruction&) { + [&is_uniform](const Instruction&) { is_uniform = true; return false; }); @@ -807,7 +804,7 @@ class LoopUnswitch { return is_uniform; } - ir::BasicBlock* parent = context_->get_instr_block(var); + BasicBlock* parent = context_->get_instr_block(var); if (!parent) { return is_uniform = true; } @@ -818,7 +815,7 @@ class LoopUnswitch { if (var->opcode() == SpvOpLoad) { const uint32_t PtrTypeId = def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id(); - const ir::Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId); + const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId); uint32_t storage_class = PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx); if (storage_class != SpvStorageClassUniform && @@ -839,50 +836,45 @@ class LoopUnswitch { } // Returns true if |insn| is constant and dynamically uniform within the loop. - bool IsConditionLoopInvariant(ir::Instruction* insn) { + bool IsConditionLoopInvariant(Instruction* insn) { assert(insn->IsBranch()); assert(insn->opcode() != SpvOpBranch); analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); - ir::Instruction* condition = - def_use_mgr->GetDef(insn->GetOperand(0).words[0]); + Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]); return !loop_->IsInsideLoop(condition) && IsDynamicallyUniform( condition, function_->entry().get(), - context_->GetPostDominatorAnalysis(function_, *context_->cfg()) - ->GetDomTree()); + context_->GetPostDominatorAnalysis(function_)->GetDomTree()); } }; } // namespace -Pass::Status LoopUnswitchPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status LoopUnswitchPass::Process() { bool modified = false; - ir::Module* module = c->module(); + Module* module = context()->module(); // Process each function in the module - for (ir::Function& f : *module) { + for (Function& f : *module) { modified |= ProcessFunction(&f); } return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -bool LoopUnswitchPass::ProcessFunction(ir::Function* f) { +bool LoopUnswitchPass::ProcessFunction(Function* f) { bool modified = false; - std::unordered_set processed_loop; + std::unordered_set processed_loop; - ir::LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); bool loop_changed = true; while (loop_changed) { loop_changed = false; - for (ir::Loop& loop : - ir::make_range(++opt::TreeDFIterator( - loop_descriptor.GetDummyRootLoop()), - opt::TreeDFIterator())) { + for (Loop& loop : + make_range(++TreeDFIterator(loop_descriptor.GetDummyRootLoop()), + TreeDFIterator())) { if (processed_loop.count(&loop)) continue; processed_loop.insert(&loop); diff --git a/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.h b/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.h index dbe581479..3ecdd6116 100644 --- a/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.h +++ b/3rdparty/spirv-tools/source/opt/loop_unswitch_pass.h @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_ -#define LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_ +#ifndef SOURCE_OPT_LOOP_UNSWITCH_PASS_H_ +#define SOURCE_OPT_LOOP_UNSWITCH_PASS_H_ -#include "opt/loop_descriptor.h" -#include "opt/pass.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -31,13 +31,13 @@ class LoopUnswitchPass : public Pass { // Processes the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is // succesful to indicate whether changes have been made to the modue. - Pass::Status Process(ir::IRContext* context) override; + Pass::Status Process() override; private: - bool ProcessFunction(ir::Function* f); + bool ProcessFunction(Function* f); }; } // namespace opt } // namespace spvtools -#endif // !LIBSPIRV_OPT_LOOP_UNSWITCH_PASS_H_ +#endif // SOURCE_OPT_LOOP_UNSWITCH_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/loop_utils.cpp b/3rdparty/spirv-tools/source/opt/loop_utils.cpp index 3c99b6a35..482335f3b 100644 --- a/3rdparty/spirv-tools/source/opt/loop_utils.cpp +++ b/3rdparty/spirv-tools/source/opt/loop_utils.cpp @@ -16,24 +16,25 @@ #include #include #include +#include #include -#include "cfa.h" -#include "opt/cfg.h" -#include "opt/ir_builder.h" -#include "opt/ir_context.h" -#include "opt/loop_descriptor.h" -#include "opt/loop_utils.h" +#include "source/cfa.h" +#include "source/opt/cfg.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" namespace spvtools { namespace opt { namespace { // Return true if |bb| is dominated by at least one block in |exits| -static inline bool DominatesAnExit( - ir::BasicBlock* bb, const std::unordered_set& exits, - const opt::DominatorTree& dom_tree) { - for (ir::BasicBlock* e_bb : exits) +static inline bool DominatesAnExit(BasicBlock* bb, + const std::unordered_set& exits, + const DominatorTree& dom_tree) { + for (BasicBlock* e_bb : exits) if (dom_tree.Dominates(bb, e_bb)) return true; return false; } @@ -47,9 +48,9 @@ static inline bool DominatesAnExit( // instruction to merge the incoming value according to exit blocks definition. class LCSSARewriter { public: - LCSSARewriter(ir::IRContext* context, const opt::DominatorTree& dom_tree, - const std::unordered_set& exit_bb, - ir::BasicBlock* merge_block) + LCSSARewriter(IRContext* context, const DominatorTree& dom_tree, + const std::unordered_set& exit_bb, + BasicBlock* merge_block) : context_(context), cfg_(context_->cfg()), dom_tree_(dom_tree), @@ -57,7 +58,7 @@ class LCSSARewriter { merge_block_id_(merge_block ? merge_block->id() : 0) {} struct UseRewriter { - explicit UseRewriter(LCSSARewriter* base, const ir::Instruction& def_insn) + explicit UseRewriter(LCSSARewriter* base, const Instruction& def_insn) : base_(base), def_insn_(def_insn) {} // Rewrites the use of |def_insn_| by the instruction |user| at the index // |operand_index| in terms of phi instruction. This recursively builds new @@ -68,8 +69,7 @@ class LCSSARewriter { // block. This operation does not update the def/use manager, instead it // records what needs to be updated. The actual update is performed by // UpdateManagers. - void RewriteUse(ir::BasicBlock* bb, ir::Instruction* user, - uint32_t operand_index) { + void RewriteUse(BasicBlock* bb, Instruction* user, uint32_t operand_index) { assert( (user->opcode() != SpvOpPhi || bb != GetParent(user)) && "The root basic block must be the incoming edge if |user| is a phi " @@ -79,7 +79,7 @@ class LCSSARewriter { "not " "phi instruction"); - ir::Instruction* new_def = GetOrBuildIncoming(bb->id()); + Instruction* new_def = GetOrBuildIncoming(bb->id()); user->SetOperand(operand_index, {new_def->result_id()}); rewritten_.insert(user); @@ -87,29 +87,28 @@ class LCSSARewriter { // In-place update of some managers (avoid full invalidation). inline void UpdateManagers() { - opt::analysis::DefUseManager* def_use_mgr = - base_->context_->get_def_use_mgr(); + analysis::DefUseManager* def_use_mgr = base_->context_->get_def_use_mgr(); // Register all new definitions. - for (ir::Instruction* insn : rewritten_) { + for (Instruction* insn : rewritten_) { def_use_mgr->AnalyzeInstDef(insn); } // Register all new uses. - for (ir::Instruction* insn : rewritten_) { + for (Instruction* insn : rewritten_) { def_use_mgr->AnalyzeInstUse(insn); } } private: // Return the basic block that |instr| belongs to. - ir::BasicBlock* GetParent(ir::Instruction* instr) { + BasicBlock* GetParent(Instruction* instr) { return base_->context_->get_instr_block(instr); } // Builds a phi instruction for the basic block |bb|. The function assumes // that |defining_blocks| contains the list of basic block that define the // usable value for each predecessor of |bb|. - inline ir::Instruction* CreatePhiInstruction( - ir::BasicBlock* bb, const std::vector& defining_blocks) { + inline Instruction* CreatePhiInstruction( + BasicBlock* bb, const std::vector& defining_blocks) { std::vector incomings; const std::vector& bb_preds = base_->cfg_->preds(bb->id()); assert(bb_preds.size() == defining_blocks.size()); @@ -118,10 +117,9 @@ class LCSSARewriter { GetOrBuildIncoming(defining_blocks[i])->result_id()); incomings.push_back(bb_preds[i]); } - opt::InstructionBuilder builder( - base_->context_, &*bb->begin(), - ir::IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* incoming_phi = + InstructionBuilder builder(base_->context_, &*bb->begin(), + IRContext::kAnalysisInstrToBlockMapping); + Instruction* incoming_phi = builder.AddPhi(def_insn_.type_id(), incomings); rewritten_.insert(incoming_phi); @@ -130,18 +128,17 @@ class LCSSARewriter { // Builds a phi instruction for the basic block |bb|, all incoming values // will be |value|. - inline ir::Instruction* CreatePhiInstruction(ir::BasicBlock* bb, - const ir::Instruction& value) { + inline Instruction* CreatePhiInstruction(BasicBlock* bb, + const Instruction& value) { std::vector incomings; const std::vector& bb_preds = base_->cfg_->preds(bb->id()); for (size_t i = 0; i < bb_preds.size(); i++) { incomings.push_back(value.result_id()); incomings.push_back(bb_preds[i]); } - opt::InstructionBuilder builder( - base_->context_, &*bb->begin(), - ir::IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* incoming_phi = + InstructionBuilder builder(base_->context_, &*bb->begin(), + IRContext::kAnalysisInstrToBlockMapping); + Instruction* incoming_phi = builder.AddPhi(def_insn_.type_id(), incomings); rewritten_.insert(incoming_phi); @@ -153,21 +150,21 @@ class LCSSARewriter { // - return the common def used by all predecessors; // - if there is no common def, then we build a new phi instr at the // beginning of |bb_id| and return this new instruction. - ir::Instruction* GetOrBuildIncoming(uint32_t bb_id) { + Instruction* GetOrBuildIncoming(uint32_t bb_id) { assert(base_->cfg_->block(bb_id) != nullptr && "Unknown basic block"); - ir::Instruction*& incoming_phi = bb_to_phi_[bb_id]; + Instruction*& incoming_phi = bb_to_phi_[bb_id]; if (incoming_phi) { return incoming_phi; } - ir::BasicBlock* bb = &*base_->cfg_->block(bb_id); + BasicBlock* bb = &*base_->cfg_->block(bb_id); // If this is an exit basic block, look if there already is an eligible // phi instruction. An eligible phi has |def_insn_| as all incoming // values. if (base_->exit_bb_.count(bb)) { // Look if there is an eligible phi in this block. - if (!bb->WhileEachPhiInst([&incoming_phi, this](ir::Instruction* phi) { + if (!bb->WhileEachPhiInst([&incoming_phi, this](Instruction* phi) { for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { if (phi->GetSingleWordInOperand(i) != def_insn_.result_id()) return true; @@ -208,9 +205,9 @@ class LCSSARewriter { } LCSSARewriter* base_; - const ir::Instruction& def_insn_; - std::unordered_map bb_to_phi_; - std::unordered_set rewritten_; + const Instruction& def_insn_; + std::unordered_map bb_to_phi_; + std::unordered_set rewritten_; }; private: @@ -226,7 +223,7 @@ class LCSSARewriter { if (defining_blocks.size()) return defining_blocks; // Check if one of the loop exit basic block dominates |bb_id|. - for (const ir::BasicBlock* e_bb : exit_bb_) { + for (const BasicBlock* e_bb : exit_bb_) { if (dom_tree_.Dominates(e_bb->id(), bb_id)) { defining_blocks.push_back(e_bb->id()); return defining_blocks; @@ -255,10 +252,10 @@ class LCSSARewriter { return defining_blocks; } - ir::IRContext* context_; - ir::CFG* cfg_; - const opt::DominatorTree& dom_tree_; - const std::unordered_set& exit_bb_; + IRContext* context_; + CFG* cfg_; + const DominatorTree& dom_tree_; + const std::unordered_set& exit_bb_; uint32_t merge_block_id_; // This map represent the set of known paths. For each key, the vector // represent the set of blocks holding the definition to be used to build the @@ -274,25 +271,25 @@ class LCSSARewriter { // Make the set |blocks| closed SSA. The set is closed SSA if all the uses // outside the set are phi instructions in exiting basic block set (hold by // |lcssa_rewriter|). -inline void MakeSetClosedSSA(ir::IRContext* context, ir::Function* function, +inline void MakeSetClosedSSA(IRContext* context, Function* function, const std::unordered_set& blocks, - const std::unordered_set& exit_bb, + const std::unordered_set& exit_bb, LCSSARewriter* lcssa_rewriter) { - ir::CFG& cfg = *context->cfg(); - opt::DominatorTree& dom_tree = - context->GetDominatorAnalysis(function, cfg)->GetDomTree(); - opt::analysis::DefUseManager* def_use_manager = context->get_def_use_mgr(); + CFG& cfg = *context->cfg(); + DominatorTree& dom_tree = + context->GetDominatorAnalysis(function)->GetDomTree(); + analysis::DefUseManager* def_use_manager = context->get_def_use_mgr(); for (uint32_t bb_id : blocks) { - ir::BasicBlock* bb = cfg.block(bb_id); + BasicBlock* bb = cfg.block(bb_id); // If bb does not dominate an exit block, then it cannot have escaping defs. if (!DominatesAnExit(bb, exit_bb, dom_tree)) continue; - for (ir::Instruction& inst : *bb) { + for (Instruction& inst : *bb) { LCSSARewriter::UseRewriter rewriter(lcssa_rewriter, inst); def_use_manager->ForEachUse( &inst, [&blocks, &rewriter, &exit_bb, context]( - ir::Instruction* use, uint32_t operand_index) { - ir::BasicBlock* use_parent = context->get_instr_block(use); + Instruction* use, uint32_t operand_index) { + BasicBlock* use_parent = context->get_instr_block(use); assert(use_parent); if (blocks.count(use_parent->id())) return; @@ -320,26 +317,25 @@ inline void MakeSetClosedSSA(ir::IRContext* context, ir::Function* function, } // namespace void LoopUtils::CreateLoopDedicatedExits() { - ir::Function* function = loop_->GetHeaderBlock()->GetParent(); - ir::LoopDescriptor& loop_desc = *context_->GetLoopDescriptor(function); - ir::CFG& cfg = *context_->cfg(); - opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + Function* function = loop_->GetHeaderBlock()->GetParent(); + LoopDescriptor& loop_desc = *context_->GetLoopDescriptor(function); + CFG& cfg = *context_->cfg(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); - const ir::IRContext::Analysis PreservedAnalyses = - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping; + const IRContext::Analysis PreservedAnalyses = + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; // Gathers the set of basic block that are not in this loop and have at least // one predecessor in the loop and one not in the loop. std::unordered_set exit_bb_set; loop_->GetExitBlocks(&exit_bb_set); - std::unordered_set new_loop_exits; + std::unordered_set new_loop_exits; bool made_change = false; // For each block, we create a new one that gathers all branches from // the loop and fall into the block. for (uint32_t non_dedicate_id : exit_bb_set) { - ir::BasicBlock* non_dedicate = cfg.block(non_dedicate_id); + BasicBlock* non_dedicate = cfg.block(non_dedicate_id); const std::vector& bb_pred = cfg.preds(non_dedicate_id); // Ignore the block if all the predecessors are in the loop. if (std::all_of(bb_pred.begin(), bb_pred.end(), @@ -349,23 +345,22 @@ void LoopUtils::CreateLoopDedicatedExits() { } made_change = true; - ir::Function::iterator insert_pt = function->begin(); + Function::iterator insert_pt = function->begin(); for (; insert_pt != function->end() && &*insert_pt != non_dedicate; ++insert_pt) { } assert(insert_pt != function->end() && "Basic Block not found"); // Create the dedicate exit basic block. - ir::BasicBlock& exit = *insert_pt.InsertBefore( - std::unique_ptr(new ir::BasicBlock( - std::unique_ptr(new ir::Instruction( - context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); + BasicBlock& exit = *insert_pt.InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( + context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); exit.SetParent(function); // Redirect in loop predecessors to |exit| block. for (uint32_t exit_pred_id : bb_pred) { if (loop_->IsInsideLoop(exit_pred_id)) { - ir::BasicBlock* pred_block = cfg.block(exit_pred_id); + BasicBlock* pred_block = cfg.block(exit_pred_id); pred_block->ForEachSuccessorLabel([non_dedicate, &exit](uint32_t* id) { if (*id == non_dedicate->id()) *id = exit.id(); }); @@ -380,50 +375,50 @@ void LoopUtils::CreateLoopDedicatedExits() { def_use_mgr->AnalyzeInstDefUse(exit.GetLabelInst()); context_->set_instr_block(exit.GetLabelInst(), &exit); - opt::InstructionBuilder builder(context_, &exit, PreservedAnalyses); + InstructionBuilder builder(context_, &exit, PreservedAnalyses); // Now jump from our dedicate basic block to the old exit. // We also reset the insert point so all instructions are inserted before // the branch. builder.SetInsertPoint(builder.AddBranch(non_dedicate->id())); - non_dedicate->ForEachPhiInst([&builder, &exit, def_use_mgr, - this](ir::Instruction* phi) { - // New phi operands for this instruction. - std::vector new_phi_op; - // Phi operands for the dedicated exit block. - std::vector exit_phi_op; - for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { - uint32_t def_id = phi->GetSingleWordInOperand(i); - uint32_t incoming_id = phi->GetSingleWordInOperand(i + 1); - if (loop_->IsInsideLoop(incoming_id)) { - exit_phi_op.push_back(def_id); - exit_phi_op.push_back(incoming_id); - } else { - new_phi_op.push_back(def_id); - new_phi_op.push_back(incoming_id); - } - } + non_dedicate->ForEachPhiInst( + [&builder, &exit, def_use_mgr, this](Instruction* phi) { + // New phi operands for this instruction. + std::vector new_phi_op; + // Phi operands for the dedicated exit block. + std::vector exit_phi_op; + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + uint32_t def_id = phi->GetSingleWordInOperand(i); + uint32_t incoming_id = phi->GetSingleWordInOperand(i + 1); + if (loop_->IsInsideLoop(incoming_id)) { + exit_phi_op.push_back(def_id); + exit_phi_op.push_back(incoming_id); + } else { + new_phi_op.push_back(def_id); + new_phi_op.push_back(incoming_id); + } + } - // Build the new phi instruction dedicated exit block. - ir::Instruction* exit_phi = builder.AddPhi(phi->type_id(), exit_phi_op); - // Build the new incoming branch. - new_phi_op.push_back(exit_phi->result_id()); - new_phi_op.push_back(exit.id()); - // Rewrite operands. - uint32_t idx = 0; - for (; idx < new_phi_op.size(); idx++) - phi->SetInOperand(idx, {new_phi_op[idx]}); - // Remove extra operands, from last to first (more efficient). - for (uint32_t j = phi->NumInOperands() - 1; j >= idx; j--) - phi->RemoveInOperand(j); - // Update the def/use manager for this |phi|. - def_use_mgr->AnalyzeInstUse(phi); - }); + // Build the new phi instruction dedicated exit block. + Instruction* exit_phi = builder.AddPhi(phi->type_id(), exit_phi_op); + // Build the new incoming branch. + new_phi_op.push_back(exit_phi->result_id()); + new_phi_op.push_back(exit.id()); + // Rewrite operands. + uint32_t idx = 0; + for (; idx < new_phi_op.size(); idx++) + phi->SetInOperand(idx, {new_phi_op[idx]}); + // Remove extra operands, from last to first (more efficient). + for (uint32_t j = phi->NumInOperands() - 1; j >= idx; j--) + phi->RemoveInOperand(j); + // Update the def/use manager for this |phi|. + def_use_mgr->AnalyzeInstUse(phi); + }); // Update the CFG. cfg.RegisterBlock(&exit); cfg.RemoveNonExistingEdges(non_dedicate->id()); new_loop_exits.insert(&exit); // If non_dedicate is in a loop, add the new dedicated exit in that loop. - if (ir::Loop* parent_loop = loop_desc[non_dedicate]) + if (Loop* parent_loop = loop_desc[non_dedicate]) parent_loop->AddBasicBlock(&exit); } @@ -433,20 +428,20 @@ void LoopUtils::CreateLoopDedicatedExits() { if (made_change) { context_->InvalidateAnalysesExceptFor( - PreservedAnalyses | ir::IRContext::kAnalysisCFG | - ir::IRContext::Analysis::kAnalysisLoopAnalysis); + PreservedAnalyses | IRContext::kAnalysisCFG | + IRContext::Analysis::kAnalysisLoopAnalysis); } } void LoopUtils::MakeLoopClosedSSA() { CreateLoopDedicatedExits(); - ir::Function* function = loop_->GetHeaderBlock()->GetParent(); - ir::CFG& cfg = *context_->cfg(); - opt::DominatorTree& dom_tree = - context_->GetDominatorAnalysis(function, cfg)->GetDomTree(); + Function* function = loop_->GetHeaderBlock()->GetParent(); + CFG& cfg = *context_->cfg(); + DominatorTree& dom_tree = + context_->GetDominatorAnalysis(function)->GetDomTree(); - std::unordered_set exit_bb; + std::unordered_set exit_bb; { std::unordered_set exit_bb_id; loop_->GetExitBlocks(&exit_bb_id); @@ -476,27 +471,94 @@ void LoopUtils::MakeLoopClosedSSA() { } context_->InvalidateAnalysesExceptFor( - ir::IRContext::Analysis::kAnalysisDefUse | - ir::IRContext::Analysis::kAnalysisCFG | - ir::IRContext::Analysis::kAnalysisDominatorAnalysis | - ir::IRContext::Analysis::kAnalysisLoopAnalysis); + IRContext::Analysis::kAnalysisCFG | + IRContext::Analysis::kAnalysisDominatorAnalysis | + IRContext::Analysis::kAnalysisLoopAnalysis); } -ir::Loop* LoopUtils::CloneLoop( +Loop* LoopUtils::CloneLoop(LoopCloningResult* cloning_result) const { + // Compute the structured order of the loop basic blocks and store it in the + // vector ordered_loop_blocks. + std::vector ordered_loop_blocks; + loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks); + + // Clone the loop. + return CloneLoop(cloning_result, ordered_loop_blocks); +} + +Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) { + // Clone the loop. + Loop* new_loop = CloneLoop(cloning_result); + + // Create a new exit block/label for the new loop. + std::unique_ptr new_label{new Instruction( + context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})}; + std::unique_ptr new_exit_bb{new BasicBlock(std::move(new_label))}; + new_exit_bb->SetParent(loop_->GetMergeBlock()->GetParent()); + + // Create an unconditional branch to the header block. + InstructionBuilder builder{context_, new_exit_bb.get()}; + builder.AddBranch(loop_->GetHeaderBlock()->id()); + + // Save the ids of the new and old merge block. + const uint32_t old_merge_block = loop_->GetMergeBlock()->id(); + const uint32_t new_merge_block = new_exit_bb->id(); + + // Replace the uses of the old merge block in the new loop with the new merge + // block. + for (std::unique_ptr& basic_block : cloning_result->cloned_bb_) { + for (Instruction& inst : *basic_block) { + // For each operand in each instruction check if it is using the old merge + // block and change it to be the new merge block. + auto replace_merge_use = [old_merge_block, + new_merge_block](uint32_t* id) { + if (*id == old_merge_block) *id = new_merge_block; + }; + inst.ForEachInOperand(replace_merge_use); + } + } + + const uint32_t old_header = loop_->GetHeaderBlock()->id(); + const uint32_t new_header = new_loop->GetHeaderBlock()->id(); + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + + def_use->ForEachUse(old_header, + [new_header, this](Instruction* inst, uint32_t operand) { + if (!this->loop_->IsInsideLoop(inst)) + inst->SetOperand(operand, {new_header}); + }); + + def_use->ForEachUse( + loop_->GetOrCreatePreHeaderBlock()->id(), + [new_merge_block, this](Instruction* inst, uint32_t operand) { + if (this->loop_->IsInsideLoop(inst)) + inst->SetOperand(operand, {new_merge_block}); + + }); + new_loop->SetMergeBlock(new_exit_bb.get()); + + new_loop->SetPreHeaderBlock(loop_->GetPreHeaderBlock()); + + // Add the new block into the cloned instructions. + cloning_result->cloned_bb_.push_back(std::move(new_exit_bb)); + + return new_loop; +} + +Loop* LoopUtils::CloneLoop( LoopCloningResult* cloning_result, - const std::vector& ordered_loop_blocks) const { + const std::vector& ordered_loop_blocks) const { analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); - std::unique_ptr new_loop = MakeUnique(context_); - if (loop_->HasParent()) new_loop->SetParent(loop_->GetParent()); + std::unique_ptr new_loop = MakeUnique(context_); - ir::CFG& cfg = *context_->cfg(); + CFG& cfg = *context_->cfg(); // Clone and place blocks in a SPIR-V compliant order (dominators first). - for (ir::BasicBlock* old_bb : ordered_loop_blocks) { + for (BasicBlock* old_bb : ordered_loop_blocks) { // For each basic block in the loop, we clone it and register the mapping // between old and new ids. - ir::BasicBlock* new_bb = old_bb->Clone(context_); + BasicBlock* new_bb = old_bb->Clone(context_); new_bb->SetParent(&function_); new_bb->GetLabelInst()->SetResultId(context_->TakeNextId()); def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst()); @@ -509,24 +571,26 @@ ir::Loop* LoopUtils::CloneLoop( if (loop_->IsInsideLoop(old_bb)) new_loop->AddBasicBlock(new_bb); - for (auto& inst : *new_bb) { - if (inst.HasResultId()) { - uint32_t old_result_id = inst.result_id(); - inst.SetResultId(context_->TakeNextId()); - cloning_result->value_map_[old_result_id] = inst.result_id(); + for (auto new_inst = new_bb->begin(), old_inst = old_bb->begin(); + new_inst != new_bb->end(); ++new_inst, ++old_inst) { + cloning_result->ptr_map_[&*new_inst] = &*old_inst; + if (new_inst->HasResultId()) { + new_inst->SetResultId(context_->TakeNextId()); + cloning_result->value_map_[old_inst->result_id()] = + new_inst->result_id(); // Only look at the defs for now, uses are not updated yet. - def_use_mgr->AnalyzeInstDef(&inst); + def_use_mgr->AnalyzeInstDef(&*new_inst); } } } // All instructions (including all labels) have been cloned, // remap instruction operands id with the new ones. - for (std::unique_ptr& bb_ref : cloning_result->cloned_bb_) { - ir::BasicBlock* bb = bb_ref.get(); + for (std::unique_ptr& bb_ref : cloning_result->cloned_bb_) { + BasicBlock* bb = bb_ref.get(); - for (ir::Instruction& insn : *bb) { + for (Instruction& insn : *bb) { insn.ForEachInId([cloning_result](uint32_t* old_id) { // If the operand is defined in the loop, remap the id. auto id_it = cloning_result->value_map_.find(*old_id); @@ -548,32 +612,31 @@ ir::Loop* LoopUtils::CloneLoop( } void LoopUtils::PopulateLoopNest( - ir::Loop* new_loop, const LoopCloningResult& cloning_result) const { - std::unordered_map loop_mapping; + Loop* new_loop, const LoopCloningResult& cloning_result) const { + std::unordered_map loop_mapping; loop_mapping[loop_] = new_loop; if (loop_->HasParent()) loop_->GetParent()->AddNestedLoop(new_loop); PopulateLoopDesc(new_loop, loop_, cloning_result); - for (ir::Loop& sub_loop : - ir::make_range(++opt::TreeDFIterator(loop_), - opt::TreeDFIterator())) { - ir::Loop* cloned = new ir::Loop(context_); - if (ir::Loop* parent = loop_mapping[sub_loop.GetParent()]) + for (Loop& sub_loop : + make_range(++TreeDFIterator(loop_), TreeDFIterator())) { + Loop* cloned = new Loop(context_); + if (Loop* parent = loop_mapping[sub_loop.GetParent()]) parent->AddNestedLoop(cloned); loop_mapping[&sub_loop] = cloned; PopulateLoopDesc(cloned, &sub_loop, cloning_result); } - loop_desc_->AddLoopNest(std::unique_ptr(new_loop)); + loop_desc_->AddLoopNest(std::unique_ptr(new_loop)); } // Populates |new_loop| descriptor according to |old_loop|'s one. void LoopUtils::PopulateLoopDesc( - ir::Loop* new_loop, ir::Loop* old_loop, + Loop* new_loop, Loop* old_loop, const LoopCloningResult& cloning_result) const { for (uint32_t bb_id : old_loop->GetBlocks()) { - ir::BasicBlock* bb = cloning_result.old_to_new_bb_.at(bb_id); + BasicBlock* bb = cloning_result.old_to_new_bb_.at(bb_id); new_loop->AddBasicBlock(bb); } new_loop->SetHeaderBlock( @@ -581,12 +644,15 @@ void LoopUtils::PopulateLoopDesc( if (old_loop->GetLatchBlock()) new_loop->SetLatchBlock( cloning_result.old_to_new_bb_.at(old_loop->GetLatchBlock()->id())); + if (old_loop->GetContinueBlock()) + new_loop->SetContinueBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetContinueBlock()->id())); if (old_loop->GetMergeBlock()) { auto it = cloning_result.old_to_new_bb_.find(old_loop->GetMergeBlock()->id()); - ir::BasicBlock* bb = it != cloning_result.old_to_new_bb_.end() - ? it->second - : old_loop->GetMergeBlock(); + BasicBlock* bb = it != cloning_result.old_to_new_bb_.end() + ? it->second + : old_loop->GetMergeBlock(); new_loop->SetMergeBlock(bb); } if (old_loop->GetPreHeaderBlock()) { @@ -598,5 +664,26 @@ void LoopUtils::PopulateLoopDesc( } } +// Class to gather some metrics about a region of interest. +void CodeMetrics::Analyze(const Loop& loop) { + CFG& cfg = *loop.GetContext()->cfg(); + + roi_size_ = 0; + block_sizes_.clear(); + + for (uint32_t id : loop.GetBlocks()) { + const BasicBlock* bb = cfg.block(id); + size_t bb_size = 0; + bb->ForEachInst([&bb_size](const Instruction* insn) { + if (insn->opcode() == SpvOpLabel) return; + if (insn->IsNop()) return; + if (insn->opcode() == SpvOpPhi) return; + bb_size++; + }); + block_sizes_[bb->id()] = bb_size; + roi_size_ += bb_size; + } +} + } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/loop_utils.h b/3rdparty/spirv-tools/source/opt/loop_utils.h index a50be3321..a4e61900b 100644 --- a/3rdparty/spirv-tools/source/opt/loop_utils.h +++ b/3rdparty/spirv-tools/source/opt/loop_utils.h @@ -14,16 +14,32 @@ #ifndef SOURCE_OPT_LOOP_UTILS_H_ #define SOURCE_OPT_LOOP_UTILS_H_ + #include #include +#include #include -#include "opt/ir_context.h" -#include "opt/loop_descriptor.h" + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" namespace spvtools { namespace opt { +// Class to gather some metrics about a Region Of Interest (ROI). +// So far it counts the number of instructions in a ROI (excluding debug +// and label instructions) per basic block and in total. +struct CodeMetrics { + void Analyze(const Loop& loop); + + // The number of instructions per basic block in the ROI. + std::unordered_map block_sizes_; + + // Number of instruction in the ROI. + size_t roi_size_; +}; + // LoopUtils is used to encapsulte loop optimizations and from the passes which // use them. Any pass which needs a loop optimization should do it through this // or through a pass which is using this. @@ -32,7 +48,10 @@ class LoopUtils { // Holds a auxiliary results of the loop cloning procedure. struct LoopCloningResult { using ValueMapTy = std::unordered_map; - using BlockMapTy = std::unordered_map; + using BlockMapTy = std::unordered_map; + using PtrMap = std::unordered_map; + + PtrMap ptr_map_; // Mapping between the original loop ids and the new one. ValueMapTy value_map_; @@ -41,10 +60,10 @@ class LoopUtils { // Mapping between the cloned loop blocks to original one. BlockMapTy new_to_old_bb_; // List of cloned basic block. - std::vector> cloned_bb_; + std::vector> cloned_bb_; }; - LoopUtils(ir::IRContext* context, ir::Loop* loop) + LoopUtils(IRContext* context, Loop* loop) : context_(context), loop_desc_( context->GetLoopDescriptor(loop->GetHeaderBlock()->GetParent())), @@ -95,9 +114,14 @@ class LoopUtils { // The function preserves the def/use, cfg and instr to block analyses. // The cloned loop nest will be added to the loop descriptor and will have // ownership. - ir::Loop* CloneLoop( - LoopCloningResult* cloning_result, - const std::vector& ordered_loop_blocks) const; + Loop* CloneLoop(LoopCloningResult* cloning_result, + const std::vector& ordered_loop_blocks) const; + // Clone |loop_| and remap its instructions, as above. Overload to compute + // loop block ordering within method rather than taking in as parameter. + Loop* CloneLoop(LoopCloningResult* cloning_result) const; + + // Clone the |loop_| and make the new loop branch to the second loop on exit. + Loop* CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result); // Perfom a partial unroll of |loop| by given |factor|. This will copy the // body of the loop |factor| times. So a |factor| of one would give a new loop @@ -129,26 +153,26 @@ class LoopUtils { void Finalize(); // Returns the context associate to |loop_|. - ir::IRContext* GetContext() { return context_; } + IRContext* GetContext() { return context_; } // Returns the loop descriptor owning |loop_|. - ir::LoopDescriptor* GetLoopDescriptor() { return loop_desc_; } + LoopDescriptor* GetLoopDescriptor() { return loop_desc_; } // Returns the loop on which the object operates on. - ir::Loop* GetLoop() const { return loop_; } + Loop* GetLoop() const { return loop_; } // Returns the function that |loop_| belong to. - ir::Function* GetFunction() const { return &function_; } + Function* GetFunction() const { return &function_; } private: - ir::IRContext* context_; - ir::LoopDescriptor* loop_desc_; - ir::Loop* loop_; - ir::Function& function_; + IRContext* context_; + LoopDescriptor* loop_desc_; + Loop* loop_; + Function& function_; // Populates the loop nest of |new_loop| according to |loop_| nest. - void PopulateLoopNest(ir::Loop* new_loop, + void PopulateLoopNest(Loop* new_loop, const LoopCloningResult& cloning_result) const; // Populates |new_loop| descriptor according to |old_loop|'s one. - void PopulateLoopDesc(ir::Loop* new_loop, ir::Loop* old_loop, + void PopulateLoopDesc(Loop* new_loop, Loop* old_loop, const LoopCloningResult& cloning_result) const; }; diff --git a/3rdparty/spirv-tools/source/opt/mem_pass.cpp b/3rdparty/spirv-tools/source/opt/mem_pass.cpp index 13364da0c..c65e04938 100644 --- a/3rdparty/spirv-tools/source/opt/mem_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/mem_pass.cpp @@ -14,13 +14,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mem_pass.h" +#include "source/opt/mem_pass.h" -#include "basic_block.h" -#include "cfa.h" -#include "dominator_analysis.h" -#include "ir_context.h" -#include "iterator.h" +#include +#include +#include + +#include "source/cfa.h" +#include "source/opt/basic_block.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" namespace spvtools { namespace opt { @@ -28,15 +32,12 @@ namespace opt { namespace { const uint32_t kCopyObjectOperandInIdx = 0; -const uint32_t kLoopMergeMergeBlockIdInIdx = 0; -const uint32_t kStoreValIdInIdx = 1; const uint32_t kTypePointerStorageClassInIdx = 0; const uint32_t kTypePointerTypeIdInIdx = 1; -const uint32_t kVariableInitIdInIdx = 1; } // namespace -bool MemPass::IsBaseTargetType(const ir::Instruction* typeInst) const { +bool MemPass::IsBaseTargetType(const Instruction* typeInst) const { switch (typeInst->opcode()) { case SpvOpTypeInt: case SpvOpTypeFloat: @@ -54,7 +55,7 @@ bool MemPass::IsBaseTargetType(const ir::Instruction* typeInst) const { return false; } -bool MemPass::IsTargetType(const ir::Instruction* typeInst) const { +bool MemPass::IsTargetType(const Instruction* typeInst) const { if (IsBaseTargetType(typeInst)) return true; if (typeInst->opcode() == SpvOpTypeArray) { if (!IsTargetType( @@ -66,7 +67,7 @@ bool MemPass::IsTargetType(const ir::Instruction* typeInst) const { if (typeInst->opcode() != SpvOpTypeStruct) return false; // All struct members must be math type return typeInst->WhileEachInId([this](const uint32_t* tid) { - ir::Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); + Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); if (!IsTargetType(compTypeInst)) return false; return true; }); @@ -78,7 +79,7 @@ bool MemPass::IsNonPtrAccessChain(const SpvOp opcode) const { bool MemPass::IsPtr(uint32_t ptrId) { uint32_t varId = ptrId; - ir::Instruction* ptrInst = get_def_use_mgr()->GetDef(varId); + Instruction* ptrInst = get_def_use_mgr()->GetDef(varId); while (ptrInst->opcode() == SpvOpCopyObject) { varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx); ptrInst = get_def_use_mgr()->GetDef(varId); @@ -87,14 +88,14 @@ bool MemPass::IsPtr(uint32_t ptrId) { if (op == SpvOpVariable || IsNonPtrAccessChain(op)) return true; if (op != SpvOpFunctionParameter) return false; const uint32_t varTypeId = ptrInst->type_id(); - const ir::Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); return varTypeInst->opcode() == SpvOpTypePointer; } -ir::Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) { +Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) { *varId = ptrId; - ir::Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId); - ir::Instruction* varInst; + Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId); + Instruction* varInst; if (ptrInst->opcode() != SpvOpVariable && ptrInst->opcode() != SpvOpFunctionParameter) { @@ -116,7 +117,7 @@ ir::Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) { return ptrInst; } -ir::Instruction* MemPass::GetPtr(ir::Instruction* ip, uint32_t* varId) { +Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) { assert(ip->opcode() == SpvOpStore || ip->opcode() == SpvOpLoad || ip->opcode() == SpvOpImageTexelPointer); @@ -126,7 +127,7 @@ ir::Instruction* MemPass::GetPtr(ir::Instruction* ip, uint32_t* varId) { } bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const { - return get_def_use_mgr()->WhileEachUser(id, [this](ir::Instruction* user) { + return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { SpvOp op = user->opcode(); if (op != SpvOpName && !IsNonTypeDecorate(op)) { return false; @@ -135,13 +136,12 @@ bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const { }); } -void MemPass::KillAllInsts(ir::BasicBlock* bp, bool killLabel) { +void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) { bp->KillAllInsts(killLabel); } bool MemPass::HasLoads(uint32_t varId) const { - return !get_def_use_mgr()->WhileEachUser(varId, [this]( - ir::Instruction* user) { + return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { SpvOp op = user->opcode(); // TODO(): The following is slightly conservative. Could be // better handling of non-store/name. @@ -157,12 +157,12 @@ bool MemPass::HasLoads(uint32_t varId) const { } bool MemPass::IsLiveVar(uint32_t varId) const { - const ir::Instruction* varInst = get_def_use_mgr()->GetDef(varId); + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); // assume live if not a variable eg. function parameter if (varInst->opcode() != SpvOpVariable) return true; // non-function scope vars are live const uint32_t varTypeId = varInst->type_id(); - const ir::Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) != SpvStorageClassFunction) return true; @@ -170,20 +170,8 @@ bool MemPass::IsLiveVar(uint32_t varId) const { return HasLoads(varId); } -bool MemPass::IsLiveStore(ir::Instruction* storeInst) { - // get store's variable - uint32_t varId; - (void)GetPtr(storeInst, &varId); - if (varId == 0) { - // If we do not know which variable we are accessing, assume the store is - // live. - return true; - } - return IsLiveVar(varId); -} - -void MemPass::AddStores(uint32_t ptr_id, std::queue* insts) { - get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](ir::Instruction* user) { +void MemPass::AddStores(uint32_t ptr_id, std::queue* insts) { + get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) { SpvOp op = user->opcode(); if (IsNonPtrAccessChain(op)) { AddStores(user->result_id(), insts); @@ -193,12 +181,12 @@ void MemPass::AddStores(uint32_t ptr_id, std::queue* insts) { }); } -void MemPass::DCEInst(ir::Instruction* inst, - const function& call_back) { - std::queue deadInsts; +void MemPass::DCEInst(Instruction* inst, + const std::function& call_back) { + std::queue deadInsts; deadInsts.push(inst); while (!deadInsts.empty()) { - ir::Instruction* di = deadInsts.front(); + Instruction* di = deadInsts.front(); // Don't delete labels if (di->opcode() == SpvOpLabel) { deadInsts.pop(); @@ -218,7 +206,7 @@ void MemPass::DCEInst(ir::Instruction* inst, // to the dead instruction queue. for (auto id : ids) if (HasOnlyNamesAndDecorates(id)) { - ir::Instruction* odi = get_def_use_mgr()->GetDef(id); + Instruction* odi = get_def_use_mgr()->GetDef(id); if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi); } // if a load was deleted and it was the variable's @@ -231,8 +219,7 @@ void MemPass::DCEInst(ir::Instruction* inst, MemPass::MemPass() {} bool MemPass::HasOnlySupportedRefs(uint32_t varId) { - if (supported_ref_vars_.find(varId) != supported_ref_vars_.end()) return true; - return get_def_use_mgr()->WhileEachUser(varId, [this](ir::Instruction* user) { + return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { SpvOp op = user->opcode(); if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && !IsNonTypeDecorate(op)) { @@ -242,264 +229,18 @@ bool MemPass::HasOnlySupportedRefs(uint32_t varId) { }); } -void MemPass::InitSSARewrite(ir::Function* func) { - // Clear collections. - visitedBlocks_.clear(); - block_defs_map_.clear(); - phis_to_patch_.clear(); - dominator_ = context()->GetDominatorAnalysis(func, *cfg()); - CollectTargetVars(func); -} - -bool MemPass::IsLiveAfter(uint32_t var_id, uint32_t label) const { - // For now, return very conservative result: true. This will result in - // correct, but possibly usused, phi code to be generated. A subsequent - // DCE pass should eliminate this code. - // TODO(greg-lunarg): Return more accurate information - (void)var_id; - (void)label; - return true; -} - uint32_t MemPass::Type2Undef(uint32_t type_id) { const auto uitr = type2undefs_.find(type_id); if (uitr != type2undefs_.end()) return uitr->second; const uint32_t undefId = TakeNextId(); - std::unique_ptr undef_inst( - new ir::Instruction(context(), SpvOpUndef, type_id, undefId, {})); + std::unique_ptr undef_inst( + new Instruction(context(), SpvOpUndef, type_id, undefId, {})); get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst); get_module()->AddGlobalValue(std::move(undef_inst)); type2undefs_[type_id] = undefId; return undefId; } -void MemPass::CollectLiveVars(uint32_t block_label, - std::map* live_vars) { - // Walk up the dominator chain starting at |block_label| looking for variables - // defined at each block in the chain. Since we are only interested for the - // most recent value for each live variable, we only add a - // pair to |live_vars| if this is the first time we find the variable in the - // chain. - for (ir::BasicBlock* block = cfg()->block(block_label); block != nullptr; - block = dominator_->ImmediateDominator(block)) { - for (const auto& var_val : block_defs_map_[block->id()]) { - auto live_vars_it = live_vars->find(var_val.first); - if (live_vars_it == live_vars->end()) live_vars->insert(var_val); - } - } -} - -uint32_t MemPass::GetCurrentValue(uint32_t var_id, uint32_t block_label) { - // Walk up the dominator chain starting at |block_label| looking for the - // current value of variable |var_id|. The first block we find containing a - // definition for |var_id| is the one we are interested in. - for (ir::BasicBlock* block = cfg()->block(block_label); block != nullptr; - block = dominator_->ImmediateDominator(block)) { - const auto& block_defs = block_defs_map_[block->id()]; - const auto& var_val_it = block_defs.find(var_id); - if (var_val_it != block_defs.end()) return var_val_it->second; - } - return 0; -} - -bool MemPass::SSABlockInitLoopHeader( - std::list::iterator block_itr) { - bool modified = false; - const uint32_t label = (*block_itr)->id(); - - // Determine the back-edge label. - uint32_t backLabel = 0; - for (uint32_t predLabel : cfg()->preds(label)) - if (visitedBlocks_.find(predLabel) == visitedBlocks_.end()) { - assert(backLabel == 0); - backLabel = predLabel; - break; - } - assert(backLabel != 0); - - // Determine merge block. - auto mergeInst = (*block_itr)->end(); - --mergeInst; - --mergeInst; - uint32_t mergeLabel = - mergeInst->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx); - - // Collect all live variables and a default value for each across all - // non-backedge predecesors. Must be ordered map because phis are - // generated based on order and test results will otherwise vary across - // platforms. - std::map liveVars; - for (uint32_t predLabel : cfg()->preds(label)) { - CollectLiveVars(predLabel, &liveVars); - } - - // Add all stored variables in loop. Set their default value id to zero. - for (auto bi = block_itr; (*bi)->id() != mergeLabel; ++bi) { - ir::BasicBlock* bp = *bi; - for (auto ii = bp->begin(); ii != bp->end(); ++ii) { - if (ii->opcode() != SpvOpStore) { - continue; - } - uint32_t varId; - (void)GetPtr(&*ii, &varId); - if (!IsTargetVar(varId)) { - continue; - } - liveVars[varId] = 0; - } - } - // Insert phi for all live variables that require them. All variables - // defined in loop require a phi. Otherwise all variables - // with differing predecessor values require a phi. - auto insertItr = (*block_itr)->begin(); - for (auto var_val : liveVars) { - const uint32_t varId = var_val.first; - if (!IsLiveAfter(varId, label)) { - continue; - } - const uint32_t val0Id = var_val.second; - bool needsPhi = false; - if (val0Id != 0) { - for (uint32_t predLabel : cfg()->preds(label)) { - // Skip back edge predecessor. - if (predLabel == backLabel) continue; - uint32_t current_value = GetCurrentValue(varId, predLabel); - // Missing (undef) values always cause difference with (defined) value - if (current_value == 0) { - needsPhi = true; - break; - } - if (current_value != val0Id) { - needsPhi = true; - break; - } - } - } else { - needsPhi = true; - } - - // If val is the same for all predecessors, enter it in map - if (!needsPhi) { - block_defs_map_[label].insert(var_val); - continue; - } - - // Val differs across predecessors. Add phi op to block and - // add its result id to the map. For back edge predecessor, - // use the variable id. We will patch this after visiting back - // edge predecessor. For predecessors that do not define a value, - // use undef. - modified = true; - std::vector phi_in_operands; - uint32_t typeId = GetPointeeTypeId(get_def_use_mgr()->GetDef(varId)); - for (uint32_t predLabel : cfg()->preds(label)) { - uint32_t valId; - if (predLabel == backLabel) { - valId = varId; - } else { - uint32_t current_value = GetCurrentValue(varId, predLabel); - if (current_value == 0) - valId = Type2Undef(typeId); - else - valId = current_value; - } - phi_in_operands.push_back( - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}); - phi_in_operands.push_back( - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {predLabel}}); - } - const uint32_t phiId = TakeNextId(); - std::unique_ptr newPhi(new ir::Instruction( - context(), SpvOpPhi, typeId, phiId, phi_in_operands)); - // The only phis requiring patching are the ones we create. - phis_to_patch_.insert(phiId); - // Only analyze the phi define now; analyze the phi uses after the - // phi backedge predecessor value is patched. - get_def_use_mgr()->AnalyzeInstDef(&*newPhi); - context()->set_instr_block(&*newPhi, *block_itr); - insertItr = insertItr.InsertBefore(std::move(newPhi)); - ++insertItr; - block_defs_map_[label].insert({varId, phiId}); - } - return modified; -} - -bool MemPass::SSABlockInitMultiPred(ir::BasicBlock* block_ptr) { - bool modified = false; - const uint32_t label = block_ptr->id(); - // Collect all live variables and a default value for each across all - // predecesors. Must be ordered map because phis are generated based on - // order and test results will otherwise vary across platforms. - std::map liveVars; - for (uint32_t predLabel : cfg()->preds(label)) { - assert(visitedBlocks_.find(predLabel) != visitedBlocks_.end()); - CollectLiveVars(predLabel, &liveVars); - } - - // For each live variable, look for a difference in values across - // predecessors that would require a phi and insert one. - auto insertItr = block_ptr->begin(); - for (auto var_val : liveVars) { - const uint32_t varId = var_val.first; - if (!IsLiveAfter(varId, label)) continue; - const uint32_t val0Id = var_val.second; - bool differs = false; - for (uint32_t predLabel : cfg()->preds(label)) { - uint32_t current_value = GetCurrentValue(varId, predLabel); - // Missing values cause a difference because we'll need to create an - // undef for that predecessor. - if (current_value == 0) { - differs = true; - break; - } - if (current_value != val0Id) { - differs = true; - break; - } - } - // If val is the same for all predecessors, enter it in map - if (!differs) { - block_defs_map_[label].insert(var_val); - continue; - } - - modified = true; - - // Val differs across predecessors. Add phi op to block and add its result - // id to the map. - std::vector phi_in_operands; - const uint32_t typeId = GetPointeeTypeId(get_def_use_mgr()->GetDef(varId)); - for (uint32_t predLabel : cfg()->preds(label)) { - uint32_t current_value = GetCurrentValue(varId, predLabel); - // If variable not defined on this path, use undef - const uint32_t valId = - (current_value > 0) ? current_value : Type2Undef(typeId); - phi_in_operands.push_back( - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}); - phi_in_operands.push_back( - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {predLabel}}); - } - const uint32_t phiId = TakeNextId(); - std::unique_ptr newPhi(new ir::Instruction( - context(), SpvOpPhi, typeId, phiId, phi_in_operands)); - get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi); - context()->set_instr_block(&*newPhi, block_ptr); - insertItr = insertItr.InsertBefore(std::move(newPhi)); - ++insertItr; - block_defs_map_[label].insert({varId, phiId}); - } - return modified; -} - -bool MemPass::SSABlockInit(std::list::iterator block_itr) { - const size_t numPreds = cfg()->preds((*block_itr)->id()).size(); - if (numPreds == 0) return false; - if ((*block_itr)->IsLoopHeader()) - return SSABlockInitLoopHeader(block_itr); - else - return SSABlockInitMultiPred(*block_itr); -} - bool MemPass::IsTargetVar(uint32_t varId) { if (varId == 0) { return false; @@ -508,10 +249,10 @@ bool MemPass::IsTargetVar(uint32_t varId) { if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end()) return false; if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true; - const ir::Instruction* varInst = get_def_use_mgr()->GetDef(varId); + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); if (varInst->opcode() != SpvOpVariable) return false; const uint32_t varTypeId = varInst->type_id(); - const ir::Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) != SpvStorageClassFunction) { seen_non_target_vars_.insert(varId); @@ -519,7 +260,7 @@ bool MemPass::IsTargetVar(uint32_t varId) { } const uint32_t varPteTypeId = varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); - ir::Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); + Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); if (!IsTargetType(varPteTypeInst)) { seen_non_target_vars_.insert(varId); return false; @@ -528,130 +269,6 @@ bool MemPass::IsTargetVar(uint32_t varId) { return true; } -void MemPass::PatchPhis(uint32_t header_id, uint32_t back_id) { - ir::BasicBlock* header = cfg()->block(header_id); - auto phiItr = header->begin(); - for (; phiItr->opcode() == SpvOpPhi; ++phiItr) { - // Only patch phis that we created in a loop header. - // There might be other phis unrelated to our optimizations. - if (0 == phis_to_patch_.count(phiItr->result_id())) continue; - - // Find phi operand index for back edge - uint32_t cnt = 0; - uint32_t idx = phiItr->NumInOperands(); - phiItr->ForEachInId([&cnt, &back_id, &idx](uint32_t* iid) { - if (cnt % 2 == 1 && *iid == back_id) idx = cnt - 1; - ++cnt; - }); - assert(idx != phiItr->NumInOperands()); - - // Replace temporary phi operand with variable's value in backedge block - // map. Use undef if variable not in map. - const uint32_t varId = phiItr->GetSingleWordInOperand(idx); - uint32_t current_value = GetCurrentValue(varId, back_id); - uint32_t valId = - (current_value > 0) - ? current_value - : Type2Undef(GetPointeeTypeId(get_def_use_mgr()->GetDef(varId))); - phiItr->SetInOperand(idx, {valId}); - // Analyze uses now that they are complete - get_def_use_mgr()->AnalyzeInstUse(&*phiItr); - } -} - -bool MemPass::InsertPhiInstructions(ir::Function* func) { - // TODO(dnovillo) the current Phi placement mechanism assumes structured - // control-flow. This should be generalized - // (https://github.com/KhronosGroup/SPIRV-Tools/issues/893). - assert(context()->get_feature_mgr()->HasCapability(SpvCapabilityShader) && - "This only works on structured control flow"); - - bool modified = false; - - // Initialize the data structures used to insert Phi instructions. - InitSSARewrite(func); - - // Process all blocks in structured order. This is just one way (the - // simplest?) to make sure all predecessors blocks are processed before - // a block itself. - std::list structuredOrder; - cfg()->ComputeStructuredOrder(func, cfg()->pseudo_entry_block(), - &structuredOrder); - for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) { - // Skip pseudo entry block - if (cfg()->IsPseudoEntryBlock(*bi)) { - continue; - } - - // Process all stores and loads of targeted variables. - if (SSABlockInit(bi)) { - modified = true; - } - - ir::BasicBlock* bp = *bi; - const uint32_t label = bp->id(); - ir::Instruction* inst = &*bp->begin(); - while (inst) { - ir::Instruction* next_instruction = inst->NextNode(); - switch (inst->opcode()) { - case SpvOpStore: { - uint32_t varId; - (void)GetPtr(inst, &varId); - if (!IsTargetVar(varId)) break; - // Register new stored value for the variable - block_defs_map_[label][varId] = - inst->GetSingleWordInOperand(kStoreValIdInIdx); - } break; - case SpvOpVariable: { - // Treat initialized OpVariable like an OpStore - if (inst->NumInOperands() < 2) break; - uint32_t varId = inst->result_id(); - if (!IsTargetVar(varId)) break; - // Register new stored value for the variable - block_defs_map_[label][varId] = - inst->GetSingleWordInOperand(kVariableInitIdInIdx); - } break; - case SpvOpLoad: { - uint32_t varId; - (void)GetPtr(inst, &varId); - if (!IsTargetVar(varId)) break; - modified = true; - uint32_t replId = GetCurrentValue(varId, label); - // If the variable is not defined, use undef. - if (replId == 0) { - replId = - Type2Undef(GetPointeeTypeId(get_def_use_mgr()->GetDef(varId))); - } - - // Replace load's id with the last stored value id for variable - // and delete load. Kill any names or decorates using id before - // replacing to prevent incorrect replacement in those instructions. - const uint32_t loadId = inst->result_id(); - context()->KillNamesAndDecorates(loadId); - (void)context()->ReplaceAllUsesWith(loadId, replId); - context()->KillInst(inst); - } break; - default: - break; - } - inst = next_instruction; - } - visitedBlocks_.insert(label); - // Look for successor backedge and patch phis in loop header - // if found. - uint32_t header = 0; - const auto* const_bp = bp; - const_bp->ForEachSuccessorLabel([&header, this](uint32_t succ) { - if (visitedBlocks_.find(succ) == visitedBlocks_.end()) return; - assert(header == 0); - header = succ; - }); - if (header != 0) PatchPhis(header, label); - } - - return modified; -} - // Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in // |reachable_blocks|). There are two types of removal that this function can // perform: @@ -697,9 +314,8 @@ bool MemPass::InsertPhiInstructions(ir::Function* func) { // [ ... ] // %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15 void MemPass::RemovePhiOperands( - ir::Instruction* phi, - std::unordered_set reachable_blocks) { - std::vector keep_operands; + Instruction* phi, const std::unordered_set& reachable_blocks) { + std::vector keep_operands; uint32_t type_id = 0; // The id of an undefined value we've generated. uint32_t undef_id = 0; @@ -719,7 +335,7 @@ void MemPass::RemovePhiOperands( assert(i % 2 == 0 && i < phi->NumOperands() - 1 && "malformed Phi arguments"); - ir::BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1)); + BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1)); if (reachable_blocks.find(in_block) == reachable_blocks.end()) { // If the incoming block is unreachable, remove both operands as this // means that the |phi| has lost an incoming edge. @@ -729,8 +345,8 @@ void MemPass::RemovePhiOperands( // In all other cases, the operand must be kept but may need to be changed. uint32_t arg_id = phi->GetSingleWordOperand(i); - ir::Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); - ir::BasicBlock* def_block = context()->get_instr_block(arg_def_instr); + Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); + BasicBlock* def_block = context()->get_instr_block(arg_def_instr); if (def_block && reachable_blocks.find(def_block) == reachable_blocks.end()) { // If the current |phi| argument was defined in an unreachable block, it @@ -741,7 +357,7 @@ void MemPass::RemovePhiOperands( undef_id = Type2Undef(type_id); } keep_operands.push_back( - ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id})); + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id})); } else { // Otherwise, the argument comes from a reachable block or from no block // at all (meaning that it was defined in the global section of the @@ -759,11 +375,11 @@ void MemPass::RemovePhiOperands( context()->AnalyzeUses(phi); } -void MemPass::RemoveBlock(ir::Function::iterator* bi) { +void MemPass::RemoveBlock(Function::iterator* bi) { auto& rm_block = **bi; // Remove instructions from the block. - rm_block.ForEachInst([&rm_block, this](ir::Instruction* inst) { + rm_block.ForEachInst([&rm_block, this](Instruction* inst) { // Note that we do not kill the block label instruction here. The label // instruction is needed to identify the block, which is needed by the // removal of phi operands. @@ -779,13 +395,13 @@ void MemPass::RemoveBlock(ir::Function::iterator* bi) { *bi = bi->Erase(); } -bool MemPass::RemoveUnreachableBlocks(ir::Function* func) { +bool MemPass::RemoveUnreachableBlocks(Function* func) { bool modified = false; // Mark reachable all blocks reachable from the function's entry block. - std::unordered_set reachable_blocks; - std::unordered_set visited_blocks; - std::queue worklist; + std::unordered_set reachable_blocks; + std::unordered_set visited_blocks; + std::queue worklist; reachable_blocks.insert(func->entry().get()); // Initially mark the function entry point as reachable. @@ -803,11 +419,11 @@ bool MemPass::RemoveUnreachableBlocks(ir::Function* func) { // Transitively mark all blocks reachable from the entry as reachable. while (!worklist.empty()) { - ir::BasicBlock* block = worklist.front(); + BasicBlock* block = worklist.front(); worklist.pop(); // All the successors of a live block are also live. - static_cast(block)->ForEachSuccessorLabel( + static_cast(block)->ForEachSuccessorLabel( mark_reachable); // All the Merge and ContinueTarget blocks of a live block are also live. @@ -825,7 +441,7 @@ bool MemPass::RemoveUnreachableBlocks(ir::Function* func) { // If the block is reachable and has Phi instructions, remove all // operands from its Phi instructions that reference unreachable blocks. // If the block has no Phi instructions, this is a no-op. - block.ForEachPhiInst([&reachable_blocks, this](ir::Instruction* phi) { + block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) { RemovePhiOperands(phi, reachable_blocks); }); } @@ -843,16 +459,15 @@ bool MemPass::RemoveUnreachableBlocks(ir::Function* func) { return modified; } -bool MemPass::CFGCleanup(ir::Function* func) { +bool MemPass::CFGCleanup(Function* func) { bool modified = false; modified |= RemoveUnreachableBlocks(func); return modified; } -void MemPass::CollectTargetVars(ir::Function* func) { +void MemPass::CollectTargetVars(Function* func) { seen_target_vars_.clear(); seen_non_target_vars_.clear(); - supported_ref_vars_.clear(); type2undefs_.clear(); // Collect target (and non-) variable sets. Remove variables with diff --git a/3rdparty/spirv-tools/source/opt/mem_pass.h b/3rdparty/spirv-tools/source/opt/mem_pass.h index f3952c3a5..67ce26b13 100644 --- a/3rdparty/spirv-tools/source/opt/mem_pass.h +++ b/3rdparty/spirv-tools/source/opt/mem_pass.h @@ -14,8 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_MEM_PASS_H_ -#define LIBSPIRV_OPT_MEM_PASS_H_ +#ifndef SOURCE_OPT_MEM_PASS_H_ +#define SOURCE_OPT_MEM_PASS_H_ #include #include @@ -25,11 +25,11 @@ #include #include -#include "basic_block.h" -#include "def_use_manager.h" -#include "dominator_analysis.h" -#include "module.h" -#include "pass.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -38,7 +38,6 @@ namespace opt { // utility functions and supporting state. class MemPass : public Pass { public: - MemPass(); virtual ~MemPass() = default; // Returns an undef value for the given |var_id|'s type. @@ -49,7 +48,7 @@ class MemPass : public Pass { // Given a load or store |ip|, return the pointer instruction. // Also return the base variable's id in |varId|. If no base variable is // found, |varId| will be 0. - ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId); + Instruction* GetPtr(Instruction* ip, uint32_t* varId); // Return true if |varId| is a previously identified target variable. // Return false if |varId| is a previously identified non-target variable. @@ -65,19 +64,20 @@ class MemPass : public Pass { // Collect target SSA variables. This traverses all the loads and stores in // function |func| looking for variables that can be replaced with SSA IDs. It - // populates the sets |seen_target_vars_|, |seen_non_target_vars_| and - // |supported_ref_vars_|. - void CollectTargetVars(ir::Function* func); + // populates the sets |seen_target_vars_| and |seen_non_target_vars_|. + void CollectTargetVars(Function* func); protected: + MemPass(); + // Returns true if |typeInst| is a scalar type // or a vector or matrix - bool IsBaseTargetType(const ir::Instruction* typeInst) const; + bool IsBaseTargetType(const Instruction* typeInst) const; // Returns true if |typeInst| is a math type or a struct or array // of a math type. // TODO(): Add more complex types to convert - bool IsTargetType(const ir::Instruction* typeInst) const; + bool IsTargetType(const Instruction* typeInst) const; // Returns true if |opcode| is a non-ptr access chain op bool IsNonPtrAccessChain(const SpvOp opcode) const; @@ -89,14 +89,14 @@ class MemPass : public Pass { // Given the id of a pointer |ptrId|, return the top-most non-CopyObj. // Also return the base variable's id in |varId|. If no base variable is // found, |varId| will be 0. - ir::Instruction* GetPtr(uint32_t ptrId, uint32_t* varId); + Instruction* GetPtr(uint32_t ptrId, uint32_t* varId); // Return true if all uses of |id| are only name or decorate ops. bool HasOnlyNamesAndDecorates(uint32_t id) const; // Kill all instructions in block |bp|. Whether or not to kill the label is // indicated by |killLabel|. - void KillAllInsts(ir::BasicBlock* bp, bool killLabel = true); + void KillAllInsts(BasicBlock* bp, bool killLabel = true); // Return true if any instruction loads from |varId| bool HasLoads(uint32_t varId) const; @@ -105,21 +105,16 @@ class MemPass : public Pass { // a load bool IsLiveVar(uint32_t varId) const; - // Return true if |storeInst| is not a function variable or if its - // base variable has a load - bool IsLiveStore(ir::Instruction* storeInst); - // Add stores using |ptr_id| to |insts| - void AddStores(uint32_t ptr_id, std::queue* insts); + void AddStores(uint32_t ptr_id, std::queue* insts); // Delete |inst| and iterate DCE on all its operands if they are now // useless. If a load is deleted and its variable has no other loads, // delete all its variable's stores. - void DCEInst(ir::Instruction* inst, - const std::function&); + void DCEInst(Instruction* inst, const std::function&); // Call all the cleanup helper functions on |func|. - bool CFGCleanup(ir::Function* func); + bool CFGCleanup(Function* func); // Return true if |op| is supported decorate. inline bool IsNonTypeDecorate(uint32_t op) const { @@ -131,11 +126,6 @@ class MemPass : public Pass { // undef to function undef map. uint32_t Type2Undef(uint32_t type_id); - // Insert Phi instructions in the CFG of |func|. This removes extra - // load/store operations to local storage while preserving the SSA form of the - // code. Returns true if the code was modified. - bool InsertPhiInstructions(ir::Function* func); - // Cache of verified target vars std::unordered_set seen_target_vars_; @@ -150,101 +140,24 @@ class MemPass : public Pass { // implementation? bool HasOnlySupportedRefs(uint32_t varId); - // Patch phis in loop header block |header_id| now that the map is complete - // for the backedge predecessor |back_id|. Specifically, for each phi, find - // the value corresponding to the backedge predecessor. That was temporarily - // set with the variable id that this phi corresponds to. Change this phi - // operand to the the value which corresponds to that variable in the - // predecessor map. - void PatchPhis(uint32_t header_id, uint32_t back_id); - - // Initialize data structures used by EliminateLocalMultiStore for - // function |func|, specifically block predecessors and target variables. - void InitSSARewrite(ir::Function* func); - - // Initialize block_defs_map_ entry for loop header block pointed to - // |block_itr| by merging entries from all predecessors. If any value - // ids differ for any variable across predecessors, create a phi function - // in the block and use that value id for the variable in the new map. - // Assumes all predecessors have been visited by EliminateLocalMultiStore - // except the back edge. Use a dummy value in the phi for the back edge - // until the back edge block is visited and patch the phi value then. - // Returns true if the code was modified. - bool SSABlockInitLoopHeader(std::list::iterator block_itr); - - // Initialize block_defs_map_ entry for multiple predecessor block - // |block_ptr| by merging block_defs_map_ entries for all predecessors. - // If any value ids differ for any variable across predecessors, create - // a phi function in the block and use that value id for the variable in - // the new map. Assumes all predecessors have been visited by - // EliminateLocalMultiStore. - // Returns true if the code was modified. - bool SSABlockInitMultiPred(ir::BasicBlock* block_ptr); - - // Initialize the label2ssa_map entry for a block pointed to by |block_itr|. - // Insert phi instructions into block when necessary. All predecessor - // blocks must have been visited by EliminateLocalMultiStore except for - // backedges. - // Returns true if the code was modified. - bool SSABlockInit(std::list::iterator block_itr); - - // Return true if variable is loaded in block with |label| or in any - // succeeding block in structured order. - bool IsLiveAfter(uint32_t var_id, uint32_t label) const; - // Remove all the unreachable basic blocks in |func|. - bool RemoveUnreachableBlocks(ir::Function* func); + bool RemoveUnreachableBlocks(Function* func); // Remove the block pointed by the iterator |*bi|. This also removes // all the instructions in the pointed-to block. - void RemoveBlock(ir::Function::iterator* bi); + void RemoveBlock(Function::iterator* bi); // Remove Phi operands in |phi| that are coming from blocks not in // |reachable_blocks|. - void RemovePhiOperands(ir::Instruction* phi, - std::unordered_set reachable_blocks); - - // Collects a map of all the live variables and their values along the path of - // dominator parents starting at |block_label|. Each entry - // |live_vars[var_id]| returns the latest value of |var_id| along that - // dominator path. Note that the mapping |live_vars| is never cleared, - // multiple calls to this function will accumulate new - // mappings. This is done to support the logic in - // MemPass::SSABlockInitLoopHeader. - void CollectLiveVars(uint32_t block_label, - std::map* live_vars); - - // Returns the ID of the most current value taken by variable |var_id| on the - // dominator path starting at |block_id|. This walks the dominator parents - // starting at |block_id| and returns the first value it finds for |var_id|. - // If no value for |var_id| is found along the dominator path, this returns 0. - uint32_t GetCurrentValue(uint32_t var_id, uint32_t block_label); - - // Dominator information. - DominatorAnalysis* dominator_; - - // Each entry |block_defs_map_[block_id]| contains a map {variable_id, - // value_id} associating all the variables |variable_id| stored in |block_id| - // to their respective value |value_id|. - std::unordered_map> - block_defs_map_; - - // Set of label ids of visited blocks - std::unordered_set visitedBlocks_; - - // Variables that are only referenced by supported operations for this - // pass ie. loads and stores. - std::unordered_set supported_ref_vars_; + void RemovePhiOperands( + Instruction* phi, + const std::unordered_set& reachable_blocks); // Map from type to undef std::unordered_map type2undefs_; - - // The Ids of OpPhi instructions that are in a loop header and which require - // patching of the value for the loop back-edge. - std::unordered_set phis_to_patch_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_MEM_PASS_H_ +#endif // SOURCE_OPT_MEM_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/merge_return_pass.cpp b/3rdparty/spirv-tools/source/opt/merge_return_pass.cpp index 7fafd5830..4068d4a18 100644 --- a/3rdparty/spirv-tools/source/opt/merge_return_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/merge_return_pass.cpp @@ -12,25 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "merge_return_pass.h" +#include "source/opt/merge_return_pass.h" -#include "instruction.h" -#include "ir_builder.h" -#include "ir_context.h" -#include "make_unique.h" -#include "reflect.h" +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { -Pass::Status MergeReturnPass::Process(ir::IRContext* irContext) { - InitializeProcessing(irContext); - +Pass::Status MergeReturnPass::Process() { bool modified = false; bool is_shader = context()->get_feature_mgr()->HasCapability(SpvCapabilityShader); for (auto& function : *get_module()) { - std::vector return_blocks = CollectReturnBlocks(&function); + std::vector return_blocks = CollectReturnBlocks(&function); if (return_blocks.size() <= 1) continue; function_ = &function; @@ -50,8 +52,8 @@ Pass::Status MergeReturnPass::Process(ir::IRContext* irContext) { } void MergeReturnPass::ProcessStructured( - ir::Function* function, const std::vector& return_blocks) { - std::list order; + Function* function, const std::vector& return_blocks) { + std::list order; cfg()->ComputeStructuredOrder(function, &*function->begin(), &order); // Create the new return block @@ -78,15 +80,40 @@ void MergeReturnPass::ProcessStructured( ProcessStructuredBlock(block); // Generate state for next block - if (ir::Instruction* mergeInst = block->GetMergeInst()) { - ir::Instruction* loopMergeInst = block->GetLoopMergeInst(); + if (Instruction* mergeInst = block->GetMergeInst()) { + Instruction* loopMergeInst = block->GetLoopMergeInst(); if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst(); state_.emplace_back(loopMergeInst, mergeInst); } } - // Predicate successors of the original return blocks as necessary. - PredicateBlocks(return_blocks); + state_.clear(); + state_.emplace_back(nullptr, nullptr); + std::unordered_set predicated; + for (auto block : order) { + if (cfg()->IsPseudoEntryBlock(block) || cfg()->IsPseudoExitBlock(block)) { + continue; + } + + auto blockId = block->id(); + if (blockId == CurrentState().CurrentMergeId()) { + // Pop the current state as we've hit the merge + state_.pop_back(); + } + + // Predicate successors of the original return blocks as necessary. + if (std::find(return_blocks.begin(), return_blocks.end(), block) != + return_blocks.end()) { + PredicateBlocks(block, &predicated, &order); + } + + // Generate state for next block + if (Instruction* mergeInst = block->GetMergeInst()) { + Instruction* loopMergeInst = block->GetLoopMergeInst(); + if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst(); + state_.emplace_back(loopMergeInst, mergeInst); + } + } // We have not kept the dominator tree up-to-date. // Invalidate it at this point to make sure it will be rebuilt. @@ -96,12 +123,12 @@ void MergeReturnPass::ProcessStructured( void MergeReturnPass::CreateReturnBlock() { // Create a label for the new return block - std::unique_ptr return_label( - new ir::Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); + std::unique_ptr return_label( + new Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); // Create the new basic block - std::unique_ptr return_block( - new ir::BasicBlock(std::move(return_label))); + std::unique_ptr return_block( + new BasicBlock(std::move(return_label))); function_->AddBasicBlock(std::move(return_block)); final_return_block_ = &*(--function_->end()); context()->AnalyzeDefUse(final_return_block_->GetLabelInst()); @@ -110,33 +137,35 @@ void MergeReturnPass::CreateReturnBlock() { final_return_block_->SetParent(function_); } -void MergeReturnPass::CreateReturn(ir::BasicBlock* block) { +void MergeReturnPass::CreateReturn(BasicBlock* block) { AddReturnValue(); if (return_value_) { // Load and return the final return value uint32_t loadId = TakeNextId(); - block->AddInstruction(MakeUnique( + block->AddInstruction(MakeUnique( context(), SpvOpLoad, function_->type_id(), loadId, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {return_value_->result_id()}}})); - ir::Instruction* var_inst = block->terminator(); + Instruction* var_inst = block->terminator(); context()->AnalyzeDefUse(var_inst); context()->set_instr_block(var_inst, block); + context()->get_decoration_mgr()->CloneDecorations( + return_value_->result_id(), loadId, {SpvDecorationRelaxedPrecision}); - block->AddInstruction(MakeUnique( + block->AddInstruction(MakeUnique( context(), SpvOpReturnValue, 0, 0, - std::initializer_list{{SPV_OPERAND_TYPE_ID, {loadId}}})); + std::initializer_list{{SPV_OPERAND_TYPE_ID, {loadId}}})); context()->AnalyzeDefUse(block->terminator()); context()->set_instr_block(block->terminator(), block); } else { - block->AddInstruction(MakeUnique(context(), SpvOpReturn)); + block->AddInstruction(MakeUnique(context(), SpvOpReturn)); context()->AnalyzeDefUse(block->terminator()); context()->set_instr_block(block->terminator(), block); } } -void MergeReturnPass::ProcessStructuredBlock(ir::BasicBlock* block) { +void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) { SpvOp tail_opcode = block->tail()->opcode(); if (tail_opcode == SpvOpReturn || tail_opcode == SpvOpReturnValue) { if (!return_flag_) { @@ -157,51 +186,53 @@ void MergeReturnPass::ProcessStructuredBlock(ir::BasicBlock* block) { } } -void MergeReturnPass::BranchToBlock(ir::BasicBlock* block, uint32_t target) { +void MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) { if (block->tail()->opcode() == SpvOpReturn || block->tail()->opcode() == SpvOpReturnValue) { RecordReturned(block); RecordReturnValue(block); } + BasicBlock* target_block = context()->get_instr_block(target); + UpdatePhiNodes(block, target_block); - // Fix up existing phi nodes. - // - // A new edge is being added from |block| to |target|, so go through - // |target|'s phi nodes add an undef incoming value for |block|. - ir::BasicBlock* target_block = context()->get_instr_block(target); - target_block->ForEachPhiInst([this, block](ir::Instruction* inst) { - uint32_t undefId = Type2Undef(inst->type_id()); - inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}}); - inst->AddOperand({SPV_OPERAND_TYPE_ID, {block->id()}}); - context()->UpdateDefUse(inst); - }); - - const auto& target_pred = cfg()->preds(target); - if (target_pred.size() == 1) { - MarkForNewPhiNodes(target_block, - context()->get_instr_block(target_pred[0])); - } - - ir::Instruction* return_inst = block->terminator(); + Instruction* return_inst = block->terminator(); return_inst->SetOpcode(SpvOpBranch); return_inst->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {target}}}); context()->get_def_use_mgr()->AnalyzeInstDefUse(return_inst); cfg()->AddEdge(block->id(), target); } -void MergeReturnPass::CreatePhiNodesForInst(ir::BasicBlock* merge_block, +void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source, + BasicBlock* target) { + target->ForEachPhiInst([this, new_source](Instruction* inst) { + uint32_t undefId = Type2Undef(inst->type_id()); + inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}}); + inst->AddOperand({SPV_OPERAND_TYPE_ID, {new_source->id()}}); + context()->UpdateDefUse(inst); + }); + + const auto& target_pred = cfg()->preds(target->id()); + if (target_pred.size() == 1) { + MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0])); + } +} + +void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t predecessor, - ir::Instruction& inst) { - opt::DominatorAnalysis* dom_tree = - context()->GetDominatorAnalysis(merge_block->GetParent(), *cfg()); - ir::BasicBlock* inst_bb = context()->get_instr_block(&inst); + Instruction& inst) { + DominatorAnalysis* dom_tree = + context()->GetDominatorAnalysis(merge_block->GetParent()); + BasicBlock* inst_bb = context()->get_instr_block(&inst); if (inst.result_id() != 0) { - std::vector users_to_update; + std::vector users_to_update; context()->get_def_use_mgr()->ForEachUser( - &inst, - [&users_to_update, &dom_tree, inst_bb, this](ir::Instruction* user) { - if (!dom_tree->Dominates(inst_bb, context()->get_instr_block(user))) { + &inst, [&users_to_update, &dom_tree, inst_bb, this](Instruction* user) { + BasicBlock* user_bb = context()->get_instr_block(user); + // If |user_bb| is nullptr, then |user| is not in the function. It is + // something like an OpName or decoration, which should not be + // replaced with the result of the OpPhi. + if (user_bb && !dom_tree->Dominates(inst_bb, user_bb)) { users_to_update.push_back(user); } }); @@ -213,7 +244,7 @@ void MergeReturnPass::CreatePhiNodesForInst(ir::BasicBlock* merge_block, // There is at least one values that needs to be replaced. // First create the OpPhi instruction. InstructionBuilder builder(context(), &*merge_block->begin(), - ir::IRContext::kAnalysisDefUse); + IRContext::kAnalysisDefUse); uint32_t undef_id = Type2Undef(inst.type_id()); std::vector phi_operands; @@ -230,11 +261,11 @@ void MergeReturnPass::CreatePhiNodesForInst(ir::BasicBlock* merge_block, } } - ir::Instruction* new_phi = builder.AddPhi(inst.type_id(), phi_operands); + Instruction* new_phi = builder.AddPhi(inst.type_id(), phi_operands); uint32_t result_of_phi = new_phi->result_id(); // Update all of the users to use the result of the new OpPhi. - for (ir::Instruction* user : users_to_update) { + for (Instruction* user : users_to_update) { user->ForEachInId([&inst, result_of_phi](uint32_t* id) { if (*id == inst.result_id()) { *id = result_of_phi; @@ -246,55 +277,89 @@ void MergeReturnPass::CreatePhiNodesForInst(ir::BasicBlock* merge_block, } void MergeReturnPass::PredicateBlocks( - const std::vector& return_blocks) { + BasicBlock* return_block, std::unordered_set* predicated, + std::list* order) { // The CFG is being modified as the function proceeds so avoid caching // successors. - std::vector stack; - auto add_successors = [this, &stack](ir::BasicBlock* block) { - const ir::BasicBlock* const_block = - const_cast(block); - const_block->ForEachSuccessorLabel([this, &stack](const uint32_t idx) { - stack.push_back(context()->get_instr_block(idx)); - }); - }; - std::unordered_set seen; - std::unordered_set predicated; - for (auto b : return_blocks) { - seen.clear(); - add_successors(b); + if (predicated->count(return_block)) { + return; + } - while (!stack.empty()) { - ir::BasicBlock* block = stack.back(); - assert(block); - stack.pop_back(); + BasicBlock* block = nullptr; + const BasicBlock* const_block = const_cast(return_block); + const_block->ForEachSuccessorLabel([this, &block](const uint32_t idx) { + BasicBlock* succ_block = context()->get_instr_block(idx); + assert(block == nullptr); + block = succ_block; + }); + assert(block && + "Return blocks should have returns already replaced by a single " + "unconditional branch."); - if (block == b) continue; - if (block == final_return_block_) continue; - if (!seen.insert(block).second) continue; - if (!predicated.insert(block).second) continue; - - // Skip structured subgraphs. - ir::BasicBlock* next = block; - while (next->GetMergeInst()) { - next = context()->get_instr_block(next->MergeBlockIdIfAny()); - } - add_successors(next); - PredicateBlock(block, next, &predicated); + auto state = state_.rbegin(); + std::unordered_set seen; + if (block->id() == state->CurrentMergeId()) { + state++; + } else if (block->id() == state->LoopMergeId()) { + while (state->LoopMergeId() == block->id()) { + state++; } } + + while (block != nullptr && block != final_return_block_) { + if (!predicated->insert(block).second) break; + // Skip structured subgraphs. + BasicBlock* next = nullptr; + if (state->InLoop()) { + next = context()->get_instr_block(state->LoopMergeId()); + while (state->LoopMergeId() == next->id()) { + state++; + } + BreakFromConstruct(block, next, predicated, order); + } else if (false && state->InStructuredFlow()) { + // TODO(#1861): This is disabled until drivers are fixed to accept + // conditional exits from a selection construct. Reenable tests when + // this code is turned back on. + + next = context()->get_instr_block(state->CurrentMergeId()); + state++; + BreakFromConstruct(block, next, predicated, order); + } else { + BasicBlock* tail = block; + while (tail->GetMergeInst()) { + tail = context()->get_instr_block(tail->MergeBlockIdIfAny()); + } + + // Must find |next| (the successor of |tail|) before predicating the + // block because, if |block| == |tail|, then |tail| will have multiple + // successors. + next = nullptr; + const_cast(tail)->ForEachSuccessorLabel( + [this, &next](const uint32_t idx) { + BasicBlock* succ_block = context()->get_instr_block(idx); + assert(next == nullptr && + "Found block with multiple successors and no merge " + "instruction."); + next = succ_block; + }); + + PredicateBlock(block, tail, predicated, order); + } + block = next; + } } -bool MergeReturnPass::RequiresPredication( - const ir::BasicBlock* block, const ir::BasicBlock* tail_block) const { +bool MergeReturnPass::RequiresPredication(const BasicBlock* block, + const BasicBlock* tail_block) const { // This is intentionally conservative. // TODO(alanbaker): re-visit this when more performance data is available. if (block != tail_block) return true; bool requires_predicate = false; - block->ForEachInst([&requires_predicate](const ir::Instruction* inst) { + block->ForEachInst([&requires_predicate](const Instruction* inst) { if (inst->opcode() != SpvOpPhi && inst->opcode() != SpvOpLabel && - !ir::IsTerminatorInst(inst->opcode())) { + !IsTerminatorInst(inst->opcode())) { requires_predicate = true; } }); @@ -302,26 +367,27 @@ bool MergeReturnPass::RequiresPredication( } void MergeReturnPass::PredicateBlock( - ir::BasicBlock* block, ir::BasicBlock* tail_block, - std::unordered_set* predicated) { + BasicBlock* block, BasicBlock* tail_block, + std::unordered_set* predicated, + std::list* order) { if (!RequiresPredication(block, tail_block)) { return; } - // Make sure the cfg is build here. If we don't then it becomes very hard to - // know which new blocks need to be updated. - context()->BuildInvalidAnalyses(ir::IRContext::kAnalysisCFG); + // Make sure the cfg is build here. If we don't then it becomes very hard + // to know which new blocks need to be updated. + context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG); - // When predicating, be aware of whether this block is a header block, a merge - // block or both. + // When predicating, be aware of whether this block is a header block, a + // merge block or both. // // If this block is a merge block, ensure the appropriate header stays // up-to-date with any changes (i.e. points to the pre-header). // - // If this block is a header block, predicate the entire structured subgraph. - // This can act recursively. + // If this block is a header block, predicate the entire structured + // subgraph. This can act recursively. - // If |block| is a loop head, then the back edge must jump to the original + // If |block| is a loop header, then the back edge must jump to the original // code, not the new header. if (block->GetLoopMergeInst()) { cfg()->SplitLoopHeader(block); @@ -336,46 +402,49 @@ void MergeReturnPass::PredicateBlock( // Forget about the edges leaving block. They will be removed. cfg()->RemoveSuccessorEdges(block); - std::unique_ptr new_block( + std::unique_ptr new_block( block->SplitBasicBlock(context(), TakeNextId(), iter)); - ir::BasicBlock* old_body = + BasicBlock* old_body = function_->InsertBasicBlockAfter(std::move(new_block), block); predicated->insert(old_body); + // Update |order| so old_block will be traversed. + InsertAfterElement(block, old_body, order); + if (tail_block == block) { tail_block = old_body; } - const ir::BasicBlock* const_old_body = - static_cast(old_body); + const BasicBlock* const_old_body = static_cast(old_body); const_old_body->ForEachSuccessorLabel( [old_body, block, this](const uint32_t label) { - ir::BasicBlock* target_bb = context()->get_instr_block(label); + BasicBlock* target_bb = context()->get_instr_block(label); if (MarkedSinglePred(target_bb) == block) { MarkForNewPhiNodes(target_bb, old_body); } }); - std::unique_ptr new_merge_block(new ir::BasicBlock( - MakeUnique(context(), SpvOpLabel, 0, TakeNextId(), - std::initializer_list{}))); + std::unique_ptr new_merge_block(new BasicBlock( + MakeUnique(context(), SpvOpLabel, 0, TakeNextId(), + std::initializer_list{}))); - ir::BasicBlock* new_merge = + BasicBlock* new_merge = function_->InsertBasicBlockAfter(std::move(new_merge_block), tail_block); predicated->insert(new_merge); new_merge->SetParent(function_); - // Register the new labels. - get_def_use_mgr()->AnalyzeInstDef(old_body->GetLabelInst()); - context()->set_instr_block(old_body->GetLabelInst(), old_body); + // Update |order| so old_block will be traversed. + InsertAfterElement(tail_block, new_merge, order); + + // Register the new label. get_def_use_mgr()->AnalyzeInstDef(new_merge->GetLabelInst()); context()->set_instr_block(new_merge->GetLabelInst(), new_merge); // Move the tail branch into the new merge and fix the mapping. If a single // block is being predicated then its branch was moved to the old body // previously. - std::unique_ptr inst; - ir::Instruction* i = tail_block->terminator(); + std::unique_ptr inst; + Instruction* i = tail_block->terminator(); cfg()->RemoveSuccessorEdges(tail_block); get_def_use_mgr()->ClearInst(i); inst.reset(std::move(i)); @@ -384,12 +453,12 @@ void MergeReturnPass::PredicateBlock( get_def_use_mgr()->AnalyzeInstUse(new_merge->terminator()); context()->set_instr_block(new_merge->terminator(), new_merge); - // Add a branch to the new merge. If we jumped multiple blocks, the branch is - // added to tail_block, otherwise the branch belongs in old_body. - tail_block->AddInstruction(MakeUnique( - context(), SpvOpBranch, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {new_merge->id()}}})); + // Add a branch to the new merge. If we jumped multiple blocks, the branch + // is added to tail_block, otherwise the branch belongs in old_body. + tail_block->AddInstruction( + MakeUnique(context(), SpvOpBranch, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {new_merge->id()}}})); get_def_use_mgr()->AnalyzeInstUse(tail_block->terminator()); context()->set_instr_block(tail_block->terminator(), tail_block); @@ -403,30 +472,28 @@ void MergeReturnPass::PredicateBlock( uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type); assert(bool_id != 0); uint32_t load_id = TakeNextId(); - block->AddInstruction(MakeUnique( + block->AddInstruction(MakeUnique( context(), SpvOpLoad, bool_id, load_id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {return_flag_->result_id()}}})); get_def_use_mgr()->AnalyzeInstDefUse(block->terminator()); context()->set_instr_block(block->terminator(), block); // 2. Declare the merge block - block->AddInstruction( - MakeUnique(context(), SpvOpSelectionMerge, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {new_merge->id()}}, - {SPV_OPERAND_TYPE_SELECTION_CONTROL, - {SpvSelectionControlMaskNone}}})); + block->AddInstruction(MakeUnique( + context(), SpvOpSelectionMerge, 0, 0, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {new_merge->id()}}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, + {SpvSelectionControlMaskNone}}})); get_def_use_mgr()->AnalyzeInstUse(block->terminator()); context()->set_instr_block(block->terminator(), block); // 3. Branch to new merge (true) or old body (false) - block->AddInstruction(MakeUnique( + block->AddInstruction(MakeUnique( context(), SpvOpBranchConditional, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {load_id}}, - {SPV_OPERAND_TYPE_ID, {new_merge->id()}}, - {SPV_OPERAND_TYPE_ID, {old_body->id()}}})); + std::initializer_list{{SPV_OPERAND_TYPE_ID, {load_id}}, + {SPV_OPERAND_TYPE_ID, {new_merge->id()}}, + {SPV_OPERAND_TYPE_ID, {old_body->id()}}})); get_def_use_mgr()->AnalyzeInstUse(block->terminator()); context()->set_instr_block(block->terminator(), block); @@ -444,7 +511,93 @@ void MergeReturnPass::PredicateBlock( MarkForNewPhiNodes(new_merge, tail_block); } -void MergeReturnPass::RecordReturned(ir::BasicBlock* block) { +void MergeReturnPass::BreakFromConstruct( + BasicBlock* block, BasicBlock* merge_block, + std::unordered_set* predicated, + std::list* order) { + // Make sure the cfg is build here. If we don't then it becomes very hard + // to know which new blocks need to be updated. + context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG); + + // When predicating, be aware of whether this block is a header block, a + // merge block or both. + // + // If this block is a merge block, ensure the appropriate header stays + // up-to-date with any changes (i.e. points to the pre-header). + // + // If this block is a header block, predicate the entire structured + // subgraph. This can act recursively. + + // If |block| is a loop header, then the back edge must jump to the original + // code, not the new header. + if (block->GetLoopMergeInst()) { + cfg()->SplitLoopHeader(block); + } + + // Leave the phi instructions behind. + auto iter = block->begin(); + while (iter->opcode() == SpvOpPhi) { + ++iter; + } + + // Forget about the edges leaving block. They will be removed. + cfg()->RemoveSuccessorEdges(block); + + std::unique_ptr new_block( + block->SplitBasicBlock(context(), TakeNextId(), iter)); + BasicBlock* old_body = + function_->InsertBasicBlockAfter(std::move(new_block), block); + predicated->insert(old_body); + + // Update |order| so old_block will be traversed. + InsertAfterElement(block, old_body, order); + + // Within the new header we need the following: + // 1. Load of the return status flag + // 2. Branch to |merge_block| (true) or old body (false) + // 3. Update OpPhi instructions in |merge_block|. + // + // Sine we are branching to the merge block of the current construct, there is + // no need for an OpSelectionMerge. + + // 1. Load of the return status flag + analysis::Bool bool_type; + uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type); + assert(bool_id != 0); + uint32_t load_id = TakeNextId(); + block->AddInstruction(MakeUnique( + context(), SpvOpLoad, bool_id, load_id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {return_flag_->result_id()}}})); + get_def_use_mgr()->AnalyzeInstDefUse(block->terminator()); + context()->set_instr_block(block->terminator(), block); + + // 2. Branch to |merge_block| (true) or |old_body| (false) + block->AddInstruction(MakeUnique( + context(), SpvOpBranchConditional, 0, 0, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {load_id}}, + {SPV_OPERAND_TYPE_ID, {merge_block->id()}}, + {SPV_OPERAND_TYPE_ID, {old_body->id()}}})); + get_def_use_mgr()->AnalyzeInstUse(block->terminator()); + context()->set_instr_block(block->terminator(), block); + + // Update the cfg + cfg()->AddEdges(block); + cfg()->RegisterBlock(old_body); + + // 3. Update OpPhi instructions in |merge_block|. + BasicBlock* merge_original_pred = MarkedSinglePred(merge_block); + if (merge_original_pred == nullptr) { + UpdatePhiNodes(block, merge_block); + } else if (merge_original_pred == block) { + MarkForNewPhiNodes(merge_block, old_body); + } + + assert(old_body->begin() != old_body->end()); + assert(block->begin() != block->end()); +} + +void MergeReturnPass::RecordReturned(BasicBlock* block) { if (block->tail()->opcode() != SpvOpReturn && block->tail()->opcode() != SpvOpReturnValue) return; @@ -463,19 +616,19 @@ void MergeReturnPass::RecordReturned(ir::BasicBlock* block) { context()->UpdateDefUse(constant_true_); } - std::unique_ptr return_store(new ir::Instruction( + std::unique_ptr return_store(new Instruction( context(), SpvOpStore, 0, 0, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {return_flag_->result_id()}}, {SPV_OPERAND_TYPE_ID, {constant_true_->result_id()}}})); - ir::Instruction* store_inst = + Instruction* store_inst = &*block->tail().InsertBefore(std::move(return_store)); context()->set_instr_block(store_inst, block); context()->AnalyzeDefUse(store_inst); } -void MergeReturnPass::RecordReturnValue(ir::BasicBlock* block) { +void MergeReturnPass::RecordReturnValue(BasicBlock* block) { auto terminator = *block->tail(); if (terminator.opcode() != SpvOpReturnValue) { return; @@ -484,13 +637,13 @@ void MergeReturnPass::RecordReturnValue(ir::BasicBlock* block) { assert(return_value_ && "Did not generate the variable to hold the return value."); - std::unique_ptr value_store(new ir::Instruction( + std::unique_ptr value_store(new Instruction( context(), SpvOpStore, 0, 0, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {return_value_->result_id()}}, {SPV_OPERAND_TYPE_ID, {terminator.GetSingleWordInOperand(0u)}}})); - ir::Instruction* store_inst = + Instruction* store_inst = &*block->tail().InsertBefore(std::move(value_store)); context()->set_instr_block(store_inst, block); context()->AnalyzeDefUse(store_inst); @@ -507,17 +660,20 @@ void MergeReturnPass::AddReturnValue() { return_type_id, SpvStorageClassFunction); uint32_t var_id = TakeNextId(); - std::unique_ptr returnValue(new ir::Instruction( + std::unique_ptr returnValue(new Instruction( context(), SpvOpVariable, return_ptr_type, var_id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); auto insert_iter = function_->begin()->begin(); insert_iter.InsertBefore(std::move(returnValue)); - ir::BasicBlock* entry_block = &*function_->begin(); + BasicBlock* entry_block = &*function_->begin(); return_value_ = &*entry_block->begin(); context()->AnalyzeDefUse(return_value_); context()->set_instr_block(return_value_, entry_block); + + context()->get_decoration_mgr()->CloneDecorations( + function_->result_id(), var_id, {SpvDecorationRelaxedPrecision}); } void MergeReturnPass::AddReturnFlag() { @@ -539,26 +695,26 @@ void MergeReturnPass::AddReturnFlag() { type_mgr->FindPointerToType(bool_id, SpvStorageClassFunction); uint32_t var_id = TakeNextId(); - std::unique_ptr returnFlag(new ir::Instruction( + std::unique_ptr returnFlag(new Instruction( context(), SpvOpVariable, bool_ptr_id, var_id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, {SPV_OPERAND_TYPE_ID, {const_false_id}}})); auto insert_iter = function_->begin()->begin(); insert_iter.InsertBefore(std::move(returnFlag)); - ir::BasicBlock* entry_block = &*function_->begin(); + BasicBlock* entry_block = &*function_->begin(); return_flag_ = &*entry_block->begin(); context()->AnalyzeDefUse(return_flag_); context()->set_instr_block(return_flag_, entry_block); } -std::vector MergeReturnPass::CollectReturnBlocks( - ir::Function* function) { - std::vector return_blocks; +std::vector MergeReturnPass::CollectReturnBlocks( + Function* function) { + std::vector return_blocks; for (auto& block : *function) { - ir::Instruction& terminator = *block.tail(); + Instruction& terminator = *block.tail(); if (terminator.opcode() == SpvOpReturn || terminator.opcode() == SpvOpReturnValue) { return_blocks.push_back(&block); @@ -568,7 +724,7 @@ std::vector MergeReturnPass::CollectReturnBlocks( } void MergeReturnPass::MergeReturnBlocks( - ir::Function* function, const std::vector& return_blocks) { + Function* function, const std::vector& return_blocks) { if (return_blocks.size() <= 1) { // No work to do. return; @@ -579,7 +735,7 @@ void MergeReturnPass::MergeReturnBlocks( auto ret_block_iter = --function->end(); // Create the PHI for the merged block (if necessary). // Create new return. - std::vector phi_ops; + std::vector phi_ops; for (auto block : return_blocks) { if (block->tail()->opcode() == SpvOpReturnValue) { phi_ops.push_back( @@ -592,23 +748,23 @@ void MergeReturnPass::MergeReturnBlocks( // Need a PHI node to select the correct return value. uint32_t phi_result_id = TakeNextId(); uint32_t phi_type_id = function->type_id(); - std::unique_ptr phi_inst(new ir::Instruction( + std::unique_ptr phi_inst(new Instruction( context(), SpvOpPhi, phi_type_id, phi_result_id, phi_ops)); ret_block_iter->AddInstruction(std::move(phi_inst)); - ir::BasicBlock::iterator phiIter = ret_block_iter->tail(); + BasicBlock::iterator phiIter = ret_block_iter->tail(); - std::unique_ptr return_inst( - new ir::Instruction(context(), SpvOpReturnValue, 0u, 0u, - {{SPV_OPERAND_TYPE_ID, {phi_result_id}}})); + std::unique_ptr return_inst( + new Instruction(context(), SpvOpReturnValue, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {phi_result_id}}})); ret_block_iter->AddInstruction(std::move(return_inst)); - ir::BasicBlock::iterator ret = ret_block_iter->tail(); + BasicBlock::iterator ret = ret_block_iter->tail(); // Register the phi def and mark instructions for use updates. get_def_use_mgr()->AnalyzeInstDefUse(&*phiIter); get_def_use_mgr()->AnalyzeInstDef(&*ret); } else { - std::unique_ptr return_inst( - new ir::Instruction(context(), SpvOpReturn)); + std::unique_ptr return_inst( + new Instruction(context(), SpvOpReturn)); ret_block_iter->AddInstruction(std::move(return_inst)); } @@ -625,37 +781,44 @@ void MergeReturnPass::MergeReturnBlocks( } void MergeReturnPass::AddNewPhiNodes() { - opt::DominatorAnalysis* dom_tree = - context()->GetDominatorAnalysis(function_, *cfg()); - std::list order; + DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_); + std::list order; cfg()->ComputeStructuredOrder(function_, &*function_->begin(), &order); - for (ir::BasicBlock* bb : order) { + for (BasicBlock* bb : order) { AddNewPhiNodes(bb, new_merge_nodes_[bb], dom_tree->ImmediateDominator(bb)->id()); } } -void MergeReturnPass::AddNewPhiNodes(ir::BasicBlock* bb, ir::BasicBlock* pred, +void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred, uint32_t header_id) { - opt::DominatorAnalysis* dom_tree = - context()->GetDominatorAnalysis(function_, *cfg()); - // Insert as a stopping point. We do not have to add anything in the block or - // above because the header dominates |bb|. + DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_); + // Insert as a stopping point. We do not have to add anything in the block + // or above because the header dominates |bb|. - ir::BasicBlock* current_bb = pred; + BasicBlock* current_bb = pred; while (current_bb != nullptr && current_bb->id() != header_id) { - for (ir::Instruction& inst : *current_bb) { + for (Instruction& inst : *current_bb) { CreatePhiNodesForInst(bb, pred->id(), inst); } current_bb = dom_tree->ImmediateDominator(current_bb); } } -void MergeReturnPass::MarkForNewPhiNodes(ir::BasicBlock* block, - ir::BasicBlock* single_original_pred) { +void MergeReturnPass::MarkForNewPhiNodes(BasicBlock* block, + BasicBlock* single_original_pred) { new_merge_nodes_[block] = single_original_pred; } +void MergeReturnPass::InsertAfterElement(BasicBlock* element, + BasicBlock* new_element, + std::list* list) { + auto pos = std::find(list->begin(), list->end(), element); + assert(pos != list->end()); + ++pos; + list->insert(pos, new_element); +} + } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/merge_return_pass.h b/3rdparty/spirv-tools/source/opt/merge_return_pass.h index b4f47e3ad..472d059fe 100644 --- a/3rdparty/spirv-tools/source/opt/merge_return_pass.h +++ b/3rdparty/spirv-tools/source/opt/merge_return_pass.h @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_MERGE_RETURN_PASS_H_ -#define LIBSPIRV_OPT_MERGE_RETURN_PASS_H_ - -#include "basic_block.h" -#include "function.h" -#include "mem_pass.h" +#ifndef SOURCE_OPT_MERGE_RETURN_PASS_H_ +#define SOURCE_OPT_MERGE_RETURN_PASS_H_ +#include #include #include +#include "source/opt/basic_block.h" +#include "source/opt/function.h" +#include "source/opt/mem_pass.h" + namespace spvtools { namespace opt { @@ -103,11 +104,11 @@ class MergeReturnPass : public MemPass { final_return_block_(nullptr) {} const char* name() const override { return "merge-return"; } - Status Process(ir::IRContext*) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - // return ir::IRContext::kAnalysisDefUse; - return ir::IRContext::kAnalysisNone; + IRContext::Analysis GetPreservedAnalyses() override { + // return IRContext::kAnalysisDefUse; + return IRContext::kAnalysisNone; } private: @@ -116,7 +117,7 @@ class MergeReturnPass : public MemPass { // contain selection construct and the inner most loop construct. class StructuredControlState { public: - StructuredControlState(ir::Instruction* loop, ir::Instruction* merge) + StructuredControlState(Instruction* loop, Instruction* merge) : loop_merge_(loop), current_merge_(merge) {} StructuredControlState(const StructuredControlState&) = default; @@ -145,29 +146,29 @@ class MergeReturnPass : public MemPass { : 0; } - ir::Instruction* LoopMergeInst() const { return loop_merge_; } + Instruction* LoopMergeInst() const { return loop_merge_; } private: - ir::Instruction* loop_merge_; - ir::Instruction* current_merge_; + Instruction* loop_merge_; + Instruction* current_merge_; }; // Returns all BasicBlocks terminated by OpReturn or OpReturnValue in // |function|. - std::vector CollectReturnBlocks(ir::Function* function); + std::vector CollectReturnBlocks(Function* function); // Creates a new basic block with a single return. If |function| returns a // value, a phi node is created to select the correct value to return. // Replaces old returns with an unconditional branch to the new block. - void MergeReturnBlocks(ir::Function* function, - const std::vector& returnBlocks); + void MergeReturnBlocks(Function* function, + const std::vector& returnBlocks); // Merges the return instruction in |function| so that it has a single return // statement. It is assumed that |function| has structured control flow, and // that |return_blocks| is a list of all of the basic blocks in |function| // that have a return. - void ProcessStructured(ir::Function* function, - const std::vector& return_blocks); + void ProcessStructured(Function* function, + const std::vector& return_blocks); // Changes an OpReturn* or OpUnreachable instruction at the end of |block| // into a store to |return_flag_|, a store to |return_value_| (if necessary), @@ -178,7 +179,7 @@ class MergeReturnPass : public MemPass { // // Note this will break the semantics. To fix this, PredicateBlock will have // to be called on the merge block the branch targets. - void ProcessStructuredBlock(ir::BasicBlock* block); + void ProcessStructuredBlock(BasicBlock* block); // Creates a variable used to store whether or not the control flow has // traversed a block that used to have a return. A pointer to the instruction @@ -192,7 +193,7 @@ class MergeReturnPass : public MemPass { // Adds a store that stores true to |return_flag_| immediately before the // terminator of |block|. It is assumed that |AddReturnFlag| has already been // called. - void RecordReturned(ir::BasicBlock* block); + void RecordReturned(BasicBlock* block); // Adds an instruction that stores the value being returned in the // OpReturnValue in |block|. The value is stored to |return_value_|, and the @@ -201,35 +202,53 @@ class MergeReturnPass : public MemPass { // If |block| does not contain an OpReturnValue, then this function has no // effect. If |block| contains an OpReturnValue, then |AddReturnValue| must // have already been called to create the variable to store to. - void RecordReturnValue(ir::BasicBlock* block); + void RecordReturnValue(BasicBlock* block); // Adds an unconditional branch in |block| that branches to |target|. It also // adds stores to |return_flag_| and |return_value_| as needed. // |AddReturnFlag| and |AddReturnValue| must have already been called. - void BranchToBlock(ir::BasicBlock* block, uint32_t target); + void BranchToBlock(BasicBlock* block, uint32_t target); - // Returns true if we need to pridicate |block| where |tail_block| is the + // Returns true if we need to predicate |block| where |tail_block| is the // merge point. (See |PredicateBlocks|). There is no need to predicate if // there is no code that could be executed. - bool RequiresPredication(const ir::BasicBlock* block, - const ir::BasicBlock* tail_block) const; + bool RequiresPredication(const BasicBlock* block, + const BasicBlock* tail_block) const; - // For every basic block that is reachable from a basic block in - // |return_blocks|, extra code is added to jump around any code that should - // not be executed because the original code would have already returned. This - // involves adding new selections constructs to jump around these - // instructions. - void PredicateBlocks(const std::vector& return_blocks); + // For every basic block that is reachable from |return_block|, extra code is + // added to jump around any code that should not be executed because the + // original code would have already returned. This involves adding new + // selections constructs to jump around these instructions. + // + // If new blocks that are created will be added to |order|. This way a call + // can traverse these new block in structured order. + void PredicateBlocks(BasicBlock* return_block, + std::unordered_set* pSet, + std::list* order); + + // Add a conditional branch at the start of |block| that either jumps to + // |merge_block| or the original code in |block| depending on the value in + // |return_flag_|. + // + // If new blocks that are created will be added to |order|. This way a call + // can traverse these new block in structured order. + void BreakFromConstruct(BasicBlock* block, BasicBlock* merge_block, + std::unordered_set* predicated, + std::list* order); // Add the predication code (see |PredicateBlocks|) to |tail_block| if it // requires predication. |tail_block| and any new blocks that are known to // not require predication will be added to |predicated|. - void PredicateBlock(ir::BasicBlock* block, ir::BasicBlock* tail_block, - std::unordered_set* predicated); + // + // If new blocks that are created will be added to |order|. This way a call + // can traverse these new block in structured order. + void PredicateBlock(BasicBlock* block, BasicBlock* tail_block, + std::unordered_set* predicated, + std::list* order); // Add an |OpReturn| or |OpReturnValue| to the end of |block|. If an // |OpReturnValue| is needed, the return value is loaded from |return_value_|. - void CreateReturn(ir::BasicBlock* block); + void CreateReturn(BasicBlock* block); // Creates a block at the end of the function that will become the single // return block at the end of the pass. @@ -239,8 +258,8 @@ class MergeReturnPass : public MemPass { // |predecessor|. Any uses of the result of |inst| that are no longer // dominated by |inst|, are replaced with the result of the new |OpPhi| // instruction. - void CreatePhiNodesForInst(ir::BasicBlock* merge_block, uint32_t predecessor, - ir::Instruction& inst); + void CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t predecessor, + Instruction& inst); // Traverse the nodes in |new_merge_nodes_|, and adds the OpPhi instructions // that are needed to make the code correct. It is assumed that at this point @@ -250,18 +269,16 @@ class MergeReturnPass : public MemPass { // Creates any new phi nodes that are needed in |bb| now that |pred| is no // longer the only block that preceedes |bb|. |header_id| is the id of the // basic block for the loop or selection construct that merges at |bb|. - void AddNewPhiNodes(ir::BasicBlock* bb, ir::BasicBlock* pred, - uint32_t header_id); + void AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred, uint32_t header_id); // Saves |block| to a list of basic block that will require OpPhi nodes to be // added by calling |AddNewPhiNodes|. It is assumed that |block| used to have // a single predecessor, |single_original_pred|, but now has more. - void MarkForNewPhiNodes(ir::BasicBlock* block, - ir::BasicBlock* single_original_pred); + void MarkForNewPhiNodes(BasicBlock* block, BasicBlock* single_original_pred); // Return the original single predcessor of |block| if it was flagged as // having a single predecessor. |nullptr| is returned otherwise. - ir::BasicBlock* MarkedSinglePred(ir::BasicBlock* block) { + BasicBlock* MarkedSinglePred(BasicBlock* block) { auto it = new_merge_nodes_.find(block); if (it != new_merge_nodes_.end()) { return it->second; @@ -270,38 +287,50 @@ class MergeReturnPass : public MemPass { } } + // Modifies existing OpPhi instruction in |target| block to account for the + // new edge from |new_source|. The value for that edge will be an Undef. If + // |target| only had a single predecessor, then it is marked as needing new + // phi nodes. See |MarkForNewPhiNodes|. + void UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target); + StructuredControlState& CurrentState() { return state_.back(); } + // Inserts |new_element| into |list| after the first occurrence of |element|. + // |element| must be in |list| at least once. + void InsertAfterElement(BasicBlock* element, BasicBlock* new_element, + std::list* list); + // A stack used to keep track of the innermost contain loop and selection // constructs. std::vector state_; // The current function being transformed. - ir::Function* function_; + Function* function_; // The |OpVariable| instruction defining a boolean variable used to keep track // of whether or not the function is trying to return. - ir::Instruction* return_flag_; + Instruction* return_flag_; // The |OpVariable| instruction defining a variabled to used to keep track of // the value that was returned when passing through a block that use to // contain an |OpReturnValue|. - ir::Instruction* return_value_; + Instruction* return_value_; // The instruction defining the boolean constant true. - ir::Instruction* constant_true_; + Instruction* constant_true_; // The basic block that is suppose to become the contain the only return value // after processing the current function. - ir::BasicBlock* final_return_block_; + BasicBlock* final_return_block_; + // This map contains the set of nodes that use to have a single predcessor, // but now have more. They will need new OpPhi nodes. For each of the nodes, // it is mapped to it original single predcessor. It is assumed there are no // values that will need a phi on the new edges. - std::unordered_map new_merge_nodes_; + std::unordered_map new_merge_nodes_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_MERGE_RETURN_PASS_H_ +#endif // SOURCE_OPT_MERGE_RETURN_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/module.cpp b/3rdparty/spirv-tools/source/opt/module.cpp index 1e87c2c04..6d024b5bc 100644 --- a/3rdparty/spirv-tools/source/opt/module.cpp +++ b/3rdparty/spirv-tools/source/opt/module.cpp @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "module.h" +#include "source/opt/module.h" #include #include #include -#include "operand.h" -#include "reflect.h" +#include "source/operand.h" +#include "source/opt/reflect.h" namespace spvtools { -namespace ir { +namespace opt { std::vector Module::GetTypes() { std::vector type_insts; @@ -65,8 +65,8 @@ uint32_t Module::GetGlobalValue(SpvOp opcode) const { void Module::AddGlobalValue(SpvOp opcode, uint32_t result_id, uint32_t type_id) { - std::unique_ptr newGlobal( - new ir::Instruction(context(), opcode, type_id, result_id, {})); + std::unique_ptr newGlobal( + new Instruction(context(), opcode, type_id, result_id, {})); AddGlobalValue(std::move(newGlobal)); } @@ -160,7 +160,7 @@ uint32_t Module::GetExtInstImportId(const char* extstr) { } std::ostream& operator<<(std::ostream& str, const Module& module) { - module.ForEachInst([&str](const ir::Instruction* inst) { + module.ForEachInst([&str](const Instruction* inst) { str << *inst; if (inst->opcode() != SpvOpFunctionEnd) { str << std::endl; @@ -169,5 +169,5 @@ std::ostream& operator<<(std::ostream& str, const Module& module) { return str; } -} // namespace ir +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/module.h b/3rdparty/spirv-tools/source/opt/module.h index 163c4e30e..eca8cc779 100644 --- a/3rdparty/spirv-tools/source/opt/module.h +++ b/3rdparty/spirv-tools/source/opt/module.h @@ -12,20 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_MODULE_H_ -#define LIBSPIRV_OPT_MODULE_H_ +#ifndef SOURCE_OPT_MODULE_H_ +#define SOURCE_OPT_MODULE_H_ #include #include #include #include -#include "function.h" -#include "instruction.h" -#include "iterator.h" +#include "source/opt/function.h" +#include "source/opt/instruction.h" +#include "source/opt/iterator.h" namespace spvtools { -namespace ir { +namespace opt { class IRContext; @@ -470,7 +470,7 @@ inline Module::const_iterator Module::cend() const { return const_iterator(&functions_, functions_.cend()); } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_MODULE_H_ +#endif // SOURCE_OPT_MODULE_H_ diff --git a/3rdparty/spirv-tools/source/opt/null_pass.h b/3rdparty/spirv-tools/source/opt/null_pass.h index 54ea06e35..2b5974fb9 100644 --- a/3rdparty/spirv-tools/source/opt/null_pass.h +++ b/3rdparty/spirv-tools/source/opt/null_pass.h @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_NULL_PASS_H_ -#define LIBSPIRV_OPT_NULL_PASS_H_ +#ifndef SOURCE_OPT_NULL_PASS_H_ +#define SOURCE_OPT_NULL_PASS_H_ -#include "module.h" -#include "pass.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -25,12 +25,10 @@ namespace opt { class NullPass : public Pass { public: const char* name() const override { return "null"; } - Status Process(ir::IRContext*) override { - return Status::SuccessWithoutChange; - } + Status Process() override { return Status::SuccessWithoutChange; } }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_NULL_PASS_H_ +#endif // SOURCE_OPT_NULL_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/optimizer.cpp b/3rdparty/spirv-tools/source/opt/optimizer.cpp index 65ab365a5..30654869c 100644 --- a/3rdparty/spirv-tools/source/opt/optimizer.cpp +++ b/3rdparty/spirv-tools/source/opt/optimizer.cpp @@ -14,11 +14,19 @@ #include "spirv-tools/optimizer.hpp" -#include "build_module.h" -#include "make_unique.h" -#include "pass_manager.h" -#include "passes.h" -#include "simplification_pass.h" +#include +#include +#include +#include +#include + +#include "source/opt/build_module.h" +#include "source/opt/log.h" +#include "source/opt/pass_manager.h" +#include "source/opt/passes.h" +#include "source/opt/reduce_load_size.h" +#include "source/opt/simplification_pass.h" +#include "source/util/make_unique.h" namespace spvtools { @@ -31,6 +39,10 @@ struct Optimizer::PassToken::Impl { Optimizer::PassToken::PassToken( std::unique_ptr impl) : impl_(std::move(impl)) {} + +Optimizer::PassToken::PassToken(std::unique_ptr&& pass) + : impl_(MakeUnique(std::move(pass))) {} + Optimizer::PassToken::PassToken(PassToken&& that) : impl_(std::move(that.impl_)) {} @@ -60,9 +72,13 @@ void Optimizer::SetMessageConsumer(MessageConsumer c) { impl_->pass_manager.SetMessageConsumer(std::move(c)); } +const MessageConsumer& Optimizer::consumer() const { + return impl_->pass_manager.consumer(); +} + Optimizer& Optimizer::RegisterPass(PassToken&& p) { // Change to use the pass manager's consumer. - p.impl_->pass->SetMessageConsumer(impl_->pass_manager.consumer()); + p.impl_->pass->SetMessageConsumer(consumer()); impl_->pass_manager.AddPass(std::move(p.impl_->pass)); return *this; } @@ -103,8 +119,8 @@ Optimizer& Optimizer::RegisterLegalizationPasses() { .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass()) - // Split up aggragates so they are easier to deal with. - .RegisterPass(CreateScalarReplacementPass()) + // Split up aggregates so they are easier to deal with. + .RegisterPass(CreateScalarReplacementPass(0)) // Remove loads and stores so everything is in intermediate values. // Takes care of copy propagation of non-members. .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) @@ -119,22 +135,23 @@ Optimizer& Optimizer::RegisterLegalizationPasses() { // Copy propagate members. Cleans up code sequences generated by // scalar replacement. Also important for removing OpPhi nodes. .RegisterPass(CreateSimplificationPass()) - .RegisterPass(CreateInsertExtractElimPass()) .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateCopyPropagateArraysPass()) // May need loop unrolling here see // https://github.com/Microsoft/DirectXShaderCompiler/pull/930 // Get rid of unused code that contain traces of illegal code // or unused references to unbound external objects + .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateReduceLoadSizePass()) .RegisterPass(CreateAggressiveDCEPass()); } Optimizer& Optimizer::RegisterPerformancePasses() { - return RegisterPass(CreateRemoveDuplicatesPass()) - .RegisterPass(CreateMergeReturnPass()) + return RegisterPass(CreateMergeReturnPass()) .RegisterPass(CreateInlineExhaustivePass()) .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreatePrivateToLocalPass()) .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass()) @@ -148,33 +165,36 @@ Optimizer& Optimizer::RegisterPerformancePasses() { .RegisterPass(CreateCCPPass()) .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateRedundancyEliminationPass()) - .RegisterPass(CreateInsertExtractElimPass()) + .RegisterPass(CreateCombineAccessChainsPass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateVectorDCEPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateIfConversionPass()) .RegisterPass(CreateCopyPropagateArraysPass()) + .RegisterPass(CreateReduceLoadSizePass()) .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateBlockMergePass()) .RegisterPass(CreateRedundancyEliminationPass()) .RegisterPass(CreateDeadBranchElimPass()) .RegisterPass(CreateBlockMergePass()) - .RegisterPass(CreateInsertExtractElimPass()); + .RegisterPass(CreateSimplificationPass()); // Currently exposing driver bugs resulting in crashes (#946) // .RegisterPass(CreateCommonUniformElimPass()) } Optimizer& Optimizer::RegisterSizePasses() { - return RegisterPass(CreateRemoveDuplicatesPass()) - .RegisterPass(CreateMergeReturnPass()) + return RegisterPass(CreateMergeReturnPass()) .RegisterPass(CreateInlineExhaustivePass()) .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreatePrivateToLocalPass()) .RegisterPass(CreateScalarReplacementPass()) .RegisterPass(CreateLocalAccessChainConvertPass()) .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) .RegisterPass(CreateLocalSingleStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass()) - .RegisterPass(CreateInsertExtractElimPass()) + .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateLocalMultiStoreElimPass()) .RegisterPass(CreateAggressiveDCEPass()) @@ -184,7 +204,7 @@ Optimizer& Optimizer::RegisterSizePasses() { .RegisterPass(CreateIfConversionPass()) .RegisterPass(CreateAggressiveDCEPass()) .RegisterPass(CreateBlockMergePass()) - .RegisterPass(CreateInsertExtractElimPass()) + .RegisterPass(CreateSimplificationPass()) .RegisterPass(CreateDeadInsertElimPass()) .RegisterPass(CreateRedundancyEliminationPass()) .RegisterPass(CreateCFGCleanupPass()) @@ -193,12 +213,268 @@ Optimizer& Optimizer::RegisterSizePasses() { .RegisterPass(CreateAggressiveDCEPass()); } +bool Optimizer::RegisterPassesFromFlags(const std::vector& flags) { + for (const auto& flag : flags) { + if (!RegisterPassFromFlag(flag)) { + return false; + } + } + + return true; +} + +namespace { + +// Splits the string |flag|, of the form '--pass_name[=pass_args]' into two +// strings "pass_name" and "pass_args". If |flag| has no arguments, the second +// string will be empty. +std::pair SplitFlagArgs(const std::string& flag) { + if (flag.size() < 2) return make_pair(flag, std::string()); + + // Detect the last dash before the pass name. Since we have to + // handle single dash options (-O and -Os), count up to two dashes. + size_t dash_ix = 0; + if (flag[0] == '-' && flag[1] == '-') + dash_ix = 2; + else if (flag[0] == '-') + dash_ix = 1; + + size_t ix = flag.find('='); + return (ix != std::string::npos) + ? make_pair(flag.substr(dash_ix, ix - 2), flag.substr(ix + 1)) + : make_pair(flag.substr(dash_ix), std::string()); +} +} // namespace + +bool Optimizer::FlagHasValidForm(const std::string& flag) const { + if (flag == "-O" || flag == "-Os") { + return true; + } else if (flag.size() > 2 && flag.substr(0, 2) == "--") { + return true; + } + + Errorf(consumer(), nullptr, {}, + "%s is not a valid flag. Flag passes should have the form " + "'--pass_name[=pass_args]'. Special flag names also accepted: -O " + "and -Os.", + flag.c_str()); + return false; +} + +bool Optimizer::RegisterPassFromFlag(const std::string& flag) { + if (!FlagHasValidForm(flag)) { + return false; + } + + // Split flags of the form --pass_name=pass_args. + auto p = SplitFlagArgs(flag); + std::string pass_name = p.first; + std::string pass_args = p.second; + + // FIXME(dnovillo): This should be re-factored so that pass names can be + // automatically checked against Pass::name() and PassToken instances created + // via a template function. Additionally, class Pass should have a desc() + // method that describes the pass (so it can be used in --help). + // + // Both Pass::name() and Pass::desc() should be static class members so they + // can be invoked without creating a pass instance. + if (pass_name == "strip-debug") { + RegisterPass(CreateStripDebugInfoPass()); + } else if (pass_name == "strip-reflect") { + RegisterPass(CreateStripReflectInfoPass()); + } else if (pass_name == "set-spec-const-default-value") { + if (pass_args.size() > 0) { + auto spec_ids_vals = + opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString( + pass_args.c_str()); + if (!spec_ids_vals) { + Errorf(consumer(), nullptr, {}, + "Invalid argument for --set-spec-const-default-value: %s", + pass_args.c_str()); + return false; + } + RegisterPass( + CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals))); + } else { + Errorf(consumer(), nullptr, {}, + "Invalid spec constant value string '%s'. Expected a string of " + ": pairs.", + pass_args.c_str()); + return false; + } + } else if (pass_name == "if-conversion") { + RegisterPass(CreateIfConversionPass()); + } else if (pass_name == "freeze-spec-const") { + RegisterPass(CreateFreezeSpecConstantValuePass()); + } else if (pass_name == "inline-entry-points-exhaustive") { + RegisterPass(CreateInlineExhaustivePass()); + } else if (pass_name == "inline-entry-points-opaque") { + RegisterPass(CreateInlineOpaquePass()); + } else if (pass_name == "combine-access-chains") { + RegisterPass(CreateCombineAccessChainsPass()); + } else if (pass_name == "convert-local-access-chains") { + RegisterPass(CreateLocalAccessChainConvertPass()); + } else if (pass_name == "eliminate-dead-code-aggressive") { + RegisterPass(CreateAggressiveDCEPass()); + } else if (pass_name == "eliminate-insert-extract") { + RegisterPass(CreateInsertExtractElimPass()); + } else if (pass_name == "eliminate-local-single-block") { + RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()); + } else if (pass_name == "eliminate-local-single-store") { + RegisterPass(CreateLocalSingleStoreElimPass()); + } else if (pass_name == "merge-blocks") { + RegisterPass(CreateBlockMergePass()); + } else if (pass_name == "merge-return") { + RegisterPass(CreateMergeReturnPass()); + } else if (pass_name == "eliminate-dead-branches") { + RegisterPass(CreateDeadBranchElimPass()); + } else if (pass_name == "eliminate-dead-functions") { + RegisterPass(CreateEliminateDeadFunctionsPass()); + } else if (pass_name == "eliminate-local-multi-store") { + RegisterPass(CreateLocalMultiStoreElimPass()); + } else if (pass_name == "eliminate-common-uniform") { + RegisterPass(CreateCommonUniformElimPass()); + } else if (pass_name == "eliminate-dead-const") { + RegisterPass(CreateEliminateDeadConstantPass()); + } else if (pass_name == "eliminate-dead-inserts") { + RegisterPass(CreateDeadInsertElimPass()); + } else if (pass_name == "eliminate-dead-variables") { + RegisterPass(CreateDeadVariableEliminationPass()); + } else if (pass_name == "fold-spec-const-op-composite") { + RegisterPass(CreateFoldSpecConstantOpAndCompositePass()); + } else if (pass_name == "loop-unswitch") { + RegisterPass(CreateLoopUnswitchPass()); + } else if (pass_name == "scalar-replacement") { + if (pass_args.size() == 0) { + RegisterPass(CreateScalarReplacementPass()); + } else { + int limit = atoi(pass_args.c_str()); + if (limit > 0) { + RegisterPass(CreateScalarReplacementPass(limit)); + } else { + Error(consumer(), nullptr, {}, + "--scalar-replacement must have no arguments or a positive " + "integer argument"); + return false; + } + } + } else if (pass_name == "strength-reduction") { + RegisterPass(CreateStrengthReductionPass()); + } else if (pass_name == "unify-const") { + RegisterPass(CreateUnifyConstantPass()); + } else if (pass_name == "flatten-decorations") { + RegisterPass(CreateFlattenDecorationPass()); + } else if (pass_name == "compact-ids") { + RegisterPass(CreateCompactIdsPass()); + } else if (pass_name == "cfg-cleanup") { + RegisterPass(CreateCFGCleanupPass()); + } else if (pass_name == "local-redundancy-elimination") { + RegisterPass(CreateLocalRedundancyEliminationPass()); + } else if (pass_name == "loop-invariant-code-motion") { + RegisterPass(CreateLoopInvariantCodeMotionPass()); + } else if (pass_name == "reduce-load-size") { + RegisterPass(CreateReduceLoadSizePass()); + } else if (pass_name == "redundancy-elimination") { + RegisterPass(CreateRedundancyEliminationPass()); + } else if (pass_name == "private-to-local") { + RegisterPass(CreatePrivateToLocalPass()); + } else if (pass_name == "remove-duplicates") { + RegisterPass(CreateRemoveDuplicatesPass()); + } else if (pass_name == "workaround-1209") { + RegisterPass(CreateWorkaround1209Pass()); + } else if (pass_name == "replace-invalid-opcode") { + RegisterPass(CreateReplaceInvalidOpcodePass()); + } else if (pass_name == "simplify-instructions") { + RegisterPass(CreateSimplificationPass()); + } else if (pass_name == "ssa-rewrite") { + RegisterPass(CreateSSARewritePass()); + } else if (pass_name == "copy-propagate-arrays") { + RegisterPass(CreateCopyPropagateArraysPass()); + } else if (pass_name == "loop-fission") { + int register_threshold_to_split = + (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1; + if (register_threshold_to_split > 0) { + RegisterPass(CreateLoopFissionPass( + static_cast(register_threshold_to_split))); + } else { + Error(consumer(), nullptr, {}, + "--loop-fission must have a positive integer argument"); + return false; + } + } else if (pass_name == "loop-fusion") { + int max_registers_per_loop = + (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1; + if (max_registers_per_loop > 0) { + RegisterPass( + CreateLoopFusionPass(static_cast(max_registers_per_loop))); + } else { + Error(consumer(), nullptr, {}, + "--loop-fusion must have a positive integer argument"); + return false; + } + } else if (pass_name == "loop-unroll") { + RegisterPass(CreateLoopUnrollPass(true)); + } else if (pass_name == "vector-dce") { + RegisterPass(CreateVectorDCEPass()); + } else if (pass_name == "loop-unroll-partial") { + int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0; + if (factor > 0) { + RegisterPass(CreateLoopUnrollPass(false, factor)); + } else { + Error(consumer(), nullptr, {}, + "--loop-unroll-partial must have a positive integer argument"); + return false; + } + } else if (pass_name == "loop-peeling") { + RegisterPass(CreateLoopPeelingPass()); + } else if (pass_name == "loop-peeling-threshold") { + int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0; + if (factor > 0) { + opt::LoopPeelingPass::SetLoopPeelingThreshold(factor); + } else { + Error(consumer(), nullptr, {}, + "--loop-peeling-threshold must have a positive integer argument"); + return false; + } + } else if (pass_name == "ccp") { + RegisterPass(CreateCCPPass()); + } else if (pass_name == "O") { + RegisterPerformancePasses(); + } else if (pass_name == "Os") { + RegisterSizePasses(); + } else if (pass_name == "legalize-hlsl") { + RegisterLegalizationPasses(); + } else { + Errorf(consumer(), nullptr, {}, + "Unknown flag '--%s'. Use --help for a list of valid flags", + pass_name.c_str()); + return false; + } + + return true; +} + bool Optimizer::Run(const uint32_t* original_binary, const size_t original_binary_size, std::vector* optimized_binary) const { - std::unique_ptr context = - BuildModule(impl_->target_env, impl_->pass_manager.consumer(), - original_binary, original_binary_size); + return Run(original_binary, original_binary_size, optimized_binary, + ValidatorOptions()); +} + +bool Optimizer::Run(const uint32_t* original_binary, + const size_t original_binary_size, + std::vector* optimized_binary, + const ValidatorOptions& options, + bool skip_validation) const { + spvtools::SpirvTools tools(impl_->target_env); + tools.SetMessageConsumer(impl_->pass_manager.consumer()); + if (!skip_validation && + !tools.Validate(original_binary, original_binary_size, options)) { + return false; + } + + std::unique_ptr context = BuildModule( + impl_->target_env, consumer(), original_binary, original_binary_size); if (context == nullptr) return false; auto status = impl_->pass_manager.Run(context.get()); @@ -321,7 +597,7 @@ Optimizer::PassToken CreateLocalSingleStoreElimPass() { Optimizer::PassToken CreateInsertExtractElimPass() { return MakeUnique( - MakeUnique()); + MakeUnique()); } Optimizer::PassToken CreateDeadInsertElimPass() { @@ -377,10 +653,25 @@ Optimizer::PassToken CreateLocalRedundancyEliminationPass() { MakeUnique()); } +Optimizer::PassToken CreateLoopFissionPass(size_t threshold) { + return MakeUnique( + MakeUnique(threshold)); +} + +Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop) { + return MakeUnique( + MakeUnique(max_registers_per_loop)); +} + Optimizer::PassToken CreateLoopInvariantCodeMotionPass() { return MakeUnique(MakeUnique()); } +Optimizer::PassToken CreateLoopPeelingPass() { + return MakeUnique( + MakeUnique()); +} + Optimizer::PassToken CreateLoopUnswitchPass() { return MakeUnique( MakeUnique()); @@ -396,9 +687,9 @@ Optimizer::PassToken CreateRemoveDuplicatesPass() { MakeUnique()); } -Optimizer::PassToken CreateScalarReplacementPass() { +Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit) { return MakeUnique( - MakeUnique()); + MakeUnique(size_limit)); } Optimizer::PassToken CreatePrivateToLocalPass() { @@ -445,4 +736,17 @@ Optimizer::PassToken CreateCopyPropagateArraysPass() { MakeUnique()); } +Optimizer::PassToken CreateVectorDCEPass() { + return MakeUnique(MakeUnique()); +} + +Optimizer::PassToken CreateReduceLoadSizePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateCombineAccessChainsPass() { + return MakeUnique( + MakeUnique()); +} } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/pass.cpp b/3rdparty/spirv-tools/source/opt/pass.cpp index c4a4befe4..4c4a232c6 100644 --- a/3rdparty/spirv-tools/source/opt/pass.cpp +++ b/3rdparty/spirv-tools/source/opt/pass.cpp @@ -14,9 +14,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass.h" +#include "source/opt/pass.h" -#include "iterator.h" +#include "source/opt/iterator.h" namespace spvtools { namespace opt { @@ -28,18 +28,18 @@ const uint32_t kTypePointerTypeIdInIdx = 1; } // namespace -Pass::Pass() : consumer_(nullptr), context_(nullptr) {} +Pass::Pass() : consumer_(nullptr), context_(nullptr), already_run_(false) {} -void Pass::AddCalls(ir::Function* func, std::queue* todo) { +void Pass::AddCalls(Function* func, std::queue* todo) { for (auto bi = func->begin(); bi != func->end(); ++bi) for (auto ii = bi->begin(); ii != bi->end(); ++ii) if (ii->opcode() == SpvOpFunctionCall) todo->push(ii->GetSingleWordInOperand(0)); } -bool Pass::ProcessEntryPointCallTree(ProcessFunction& pfn, ir::Module* module) { +bool Pass::ProcessEntryPointCallTree(ProcessFunction& pfn, Module* module) { // Map from function's result id to function - std::unordered_map id2function; + std::unordered_map id2function; for (auto& fn : *module) id2function[fn.result_id()] = &fn; // Collect all of the entry points as the roots. @@ -50,9 +50,9 @@ bool Pass::ProcessEntryPointCallTree(ProcessFunction& pfn, ir::Module* module) { } bool Pass::ProcessReachableCallTree(ProcessFunction& pfn, - ir::IRContext* irContext) { + IRContext* irContext) { // Map from function's result id to function - std::unordered_map id2function; + std::unordered_map id2function; for (auto& fn : *irContext->module()) id2function[fn.result_id()] = &fn; std::queue roots; @@ -84,7 +84,7 @@ bool Pass::ProcessReachableCallTree(ProcessFunction& pfn, bool Pass::ProcessCallTreeFromRoots( ProcessFunction& pfn, - const std::unordered_map& id2function, + const std::unordered_map& id2function, std::queue* roots) { // Process call tree bool modified = false; @@ -94,7 +94,7 @@ bool Pass::ProcessCallTreeFromRoots( const uint32_t fi = roots->front(); roots->pop(); if (done.insert(fi).second) { - ir::Function* fn = id2function.at(fi); + Function* fn = id2function.at(fi); modified = pfn(fn) || modified; AddCalls(fn, roots); } @@ -102,8 +102,16 @@ bool Pass::ProcessCallTreeFromRoots( return modified; } -Pass::Status Pass::Run(ir::IRContext* ctx) { - Pass::Status status = Process(ctx); +Pass::Status Pass::Run(IRContext* ctx) { + if (already_run_) { + return Status::Failure; + } + already_run_ = true; + + context_ = ctx; + Pass::Status status = Process(); + context_ = nullptr; + if (status == Status::SuccessWithChange) { ctx->InvalidateAnalysesExceptFor(GetPreservedAnalyses()); } @@ -111,9 +119,9 @@ Pass::Status Pass::Run(ir::IRContext* ctx) { return status; } -uint32_t Pass::GetPointeeTypeId(const ir::Instruction* ptrInst) const { +uint32_t Pass::GetPointeeTypeId(const Instruction* ptrInst) const { const uint32_t ptrTypeId = ptrInst->type_id(); - const ir::Instruction* ptrTypeInst = get_def_use_mgr()->GetDef(ptrTypeId); + const Instruction* ptrTypeInst = get_def_use_mgr()->GetDef(ptrTypeId); return ptrTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); } diff --git a/3rdparty/spirv-tools/source/opt/pass.h b/3rdparty/spirv-tools/source/opt/pass.h index 3077112df..df1745099 100644 --- a/3rdparty/spirv-tools/source/opt/pass.h +++ b/3rdparty/spirv-tools/source/opt/pass.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_PASS_H_ -#define LIBSPIRV_OPT_PASS_H_ +#ifndef SOURCE_OPT_PASS_H_ +#define SOURCE_OPT_PASS_H_ #include #include @@ -22,10 +22,10 @@ #include #include -#include "basic_block.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { @@ -45,14 +45,7 @@ class Pass { SuccessWithoutChange = 0x11, }; - using ProcessFunction = std::function; - - // Constructs a new pass. - // - // The constructed instance will have an empty message consumer, which just - // ignores all messages from the library. Use SetMessageConsumer() to supply - // one if messages are of concern. - Pass(); + using ProcessFunction = std::function; // Destructs the pass. virtual ~Pass() = default; @@ -87,26 +80,29 @@ class Pass { } // Returns a pointer to the current module for this pass. - ir::Module* get_module() const { return context_->module(); } + Module* get_module() const { return context_->module(); } + + // Sets the pointer to the current context for this pass. + void SetContextForTesting(IRContext* ctx) { context_ = ctx; } // Returns a pointer to the current context for this pass. - ir::IRContext* context() const { return context_; } + IRContext* context() const { return context_; } // Returns a pointer to the CFG for current module. - ir::CFG* cfg() const { return context()->cfg(); } + CFG* cfg() const { return context()->cfg(); } // Add to |todo| all ids of functions called in |func|. - void AddCalls(ir::Function* func, std::queue* todo); + void AddCalls(Function* func, std::queue* todo); // Applies |pfn| to every function in the call trees that are rooted at the // entry points. Returns true if any call |pfn| returns true. By convention // |pfn| should return true if it modified the module. - bool ProcessEntryPointCallTree(ProcessFunction& pfn, ir::Module* module); + bool ProcessEntryPointCallTree(ProcessFunction& pfn, Module* module); // Applies |pfn| to every function in the call trees rooted at the entry // points and exported functions. Returns true if any call |pfn| returns // true. By convention |pfn| should return true if it modified the module. - bool ProcessReachableCallTree(ProcessFunction& pfn, ir::IRContext* irContext); + bool ProcessReachableCallTree(ProcessFunction& pfn, IRContext* irContext); // Applies |pfn| to every function in the call trees rooted at the elements of // |roots|. Returns true if any call to |pfn| returns true. By convention @@ -114,34 +110,39 @@ class Pass { // |roots| will be empty. bool ProcessCallTreeFromRoots( ProcessFunction& pfn, - const std::unordered_map& id2function, + const std::unordered_map& id2function, std::queue* roots); // Run the pass on the given |module|. Returns Status::Failure if errors occur - // when - // processing. Returns the corresponding Status::Success if processing is + // when processing. Returns the corresponding Status::Success if processing is // successful to indicate whether changes are made to the module. If there // were any changes it will also invalidate the analyses in the IRContext // that are not preserved. - virtual Status Run(ir::IRContext* ctx) final; + // + // It is an error if |Run| is called twice with the same instance of the pass. + // If this happens the return value will be |Failure|. + Status Run(IRContext* ctx); // Returns the set of analyses that the pass is guaranteed to preserve. - virtual ir::IRContext::Analysis GetPreservedAnalyses() { - return ir::IRContext::kAnalysisNone; + virtual IRContext::Analysis GetPreservedAnalyses() { + return IRContext::kAnalysisNone; } // Return type id for |ptrInst|'s pointee - uint32_t GetPointeeTypeId(const ir::Instruction* ptrInst) const; + uint32_t GetPointeeTypeId(const Instruction* ptrInst) const; protected: - // Initialize basic data structures for the pass. This sets up the def-use - // manager, module and other attributes. - virtual void InitializeProcessing(ir::IRContext* c) { context_ = c; } + // Constructs a new pass. + // + // The constructed instance will have an empty message consumer, which just + // ignores all messages from the library. Use SetMessageConsumer() to supply + // one if messages are of concern. + Pass(); // Processes the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is // succesful to indicate whether changes are made to the module. - virtual Status Process(ir::IRContext* context) = 0; + virtual Status Process() = 0; // Return the next available SSA id and increment it. uint32_t TakeNextId() { return context_->TakeNextId(); } @@ -150,10 +151,15 @@ class Pass { MessageConsumer consumer_; // Message consumer. // The context that this pass belongs to. - ir::IRContext* context_; + IRContext* context_; + + // An instance of a pass can only be run once because it is too hard to + // enforce proper resetting of internal state for each instance. This member + // is used to check that we do not run the same instance twice. + bool already_run_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_PASS_H_ +#endif // SOURCE_OPT_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/pass_manager.cpp b/3rdparty/spirv-tools/source/opt/pass_manager.cpp index c735bbb07..fa1e1d8a8 100644 --- a/3rdparty/spirv-tools/source/opt/pass_manager.cpp +++ b/3rdparty/spirv-tools/source/opt/pass_manager.cpp @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_manager.h" +#include "source/opt/pass_manager.h" #include +#include #include -#include "ir_context.h" +#include "source/opt/ir_context.h" +#include "source/util/timer.h" #include "spirv-tools/libspirv.hpp" -#include "util/timer.h" namespace spvtools { namespace opt { -Pass::Status PassManager::Run(ir::IRContext* context) { +Pass::Status PassManager::Run(IRContext* context) { auto status = Pass::Status::SuccessWithoutChange; // If print_all_stream_ is not null, prints the disassembly of the module @@ -43,12 +44,15 @@ Pass::Status PassManager::Run(ir::IRContext* context) { }; SPIRV_TIMER_DESCRIPTION(time_report_stream_, /* measure_mem_usage = */ true); - for (const auto& pass : passes_) { + for (auto& pass : passes_) { print_disassembly("; IR before pass ", pass.get()); SPIRV_TIMER_SCOPED(time_report_stream_, (pass ? pass->name() : ""), true); const auto one_status = pass->Run(context); if (one_status == Pass::Status::Failure) return one_status; if (one_status == Pass::Status::SuccessWithChange) status = one_status; + + // Reset the pass to free any memory used by the pass. + pass.reset(nullptr); } print_disassembly("; IR after last pass", nullptr); diff --git a/3rdparty/spirv-tools/source/opt/pass_manager.h b/3rdparty/spirv-tools/source/opt/pass_manager.h index da3620a9e..ed88aa17c 100644 --- a/3rdparty/spirv-tools/source/opt/pass_manager.h +++ b/3rdparty/spirv-tools/source/opt/pass_manager.h @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_PASS_MANAGER_H_ -#define LIBSPIRV_OPT_PASS_MANAGER_H_ +#ifndef SOURCE_OPT_PASS_MANAGER_H_ +#define SOURCE_OPT_PASS_MANAGER_H_ #include #include +#include #include -#include "log.h" -#include "module.h" -#include "pass.h" +#include "source/opt/log.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" -#include "ir_context.h" +#include "source/opt/ir_context.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { @@ -70,7 +71,7 @@ class PassManager { // whether changes are made to the module. // // After running all the passes, they are removed from the list. - Pass::Status Run(ir::IRContext* context); + Pass::Status Run(IRContext* context); // Sets the option to print the disassembly before each pass and after the // last pass. Output is written to |out| if that is not null. No output @@ -127,4 +128,4 @@ inline const MessageConsumer& PassManager::consumer() const { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_PASS_MANAGER_H_ +#endif // SOURCE_OPT_PASS_MANAGER_H_ diff --git a/3rdparty/spirv-tools/source/opt/passes.h b/3rdparty/spirv-tools/source/opt/passes.h index eccda9c7e..42106c8f7 100644 --- a/3rdparty/spirv-tools/source/opt/passes.h +++ b/3rdparty/spirv-tools/source/opt/passes.h @@ -12,50 +12,56 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_PASSES_H_ -#define LIBSPIRV_OPT_PASSES_H_ +#ifndef SOURCE_OPT_PASSES_H_ +#define SOURCE_OPT_PASSES_H_ // A single header to include all passes. -#include "aggressive_dead_code_elim_pass.h" -#include "block_merge_pass.h" -#include "ccp_pass.h" -#include "cfg_cleanup_pass.h" -#include "common_uniform_elim_pass.h" -#include "compact_ids_pass.h" -#include "copy_prop_arrays.h" -#include "dead_branch_elim_pass.h" -#include "dead_insert_elim_pass.h" -#include "dead_variable_elimination.h" -#include "eliminate_dead_constant_pass.h" -#include "eliminate_dead_functions_pass.h" -#include "flatten_decoration_pass.h" -#include "fold_spec_constant_op_and_composite_pass.h" -#include "freeze_spec_constant_value_pass.h" -#include "if_conversion.h" -#include "inline_exhaustive_pass.h" -#include "inline_opaque_pass.h" -#include "insert_extract_elim.h" -#include "licm_pass.h" -#include "local_access_chain_convert_pass.h" -#include "local_redundancy_elimination.h" -#include "local_single_block_elim_pass.h" -#include "local_single_store_elim_pass.h" -#include "local_ssa_elim_pass.h" -#include "loop_unroller.h" -#include "loop_unswitch_pass.h" -#include "merge_return_pass.h" -#include "null_pass.h" -#include "private_to_local_pass.h" -#include "redundancy_elimination.h" -#include "remove_duplicates_pass.h" -#include "replace_invalid_opc.h" -#include "scalar_replacement_pass.h" -#include "set_spec_constant_default_value_pass.h" -#include "ssa_rewrite_pass.h" -#include "strength_reduction_pass.h" -#include "strip_debug_info_pass.h" -#include "strip_reflect_info_pass.h" -#include "unify_const_pass.h" -#include "workaround1209.h" -#endif // LIBSPIRV_OPT_PASSES_H_ +#include "source/opt/aggressive_dead_code_elim_pass.h" +#include "source/opt/block_merge_pass.h" +#include "source/opt/ccp_pass.h" +#include "source/opt/cfg_cleanup_pass.h" +#include "source/opt/combine_access_chains.h" +#include "source/opt/common_uniform_elim_pass.h" +#include "source/opt/compact_ids_pass.h" +#include "source/opt/copy_prop_arrays.h" +#include "source/opt/dead_branch_elim_pass.h" +#include "source/opt/dead_insert_elim_pass.h" +#include "source/opt/dead_variable_elimination.h" +#include "source/opt/eliminate_dead_constant_pass.h" +#include "source/opt/eliminate_dead_functions_pass.h" +#include "source/opt/flatten_decoration_pass.h" +#include "source/opt/fold_spec_constant_op_and_composite_pass.h" +#include "source/opt/freeze_spec_constant_value_pass.h" +#include "source/opt/if_conversion.h" +#include "source/opt/inline_exhaustive_pass.h" +#include "source/opt/inline_opaque_pass.h" +#include "source/opt/licm_pass.h" +#include "source/opt/local_access_chain_convert_pass.h" +#include "source/opt/local_redundancy_elimination.h" +#include "source/opt/local_single_block_elim_pass.h" +#include "source/opt/local_single_store_elim_pass.h" +#include "source/opt/local_ssa_elim_pass.h" +#include "source/opt/loop_fission.h" +#include "source/opt/loop_fusion_pass.h" +#include "source/opt/loop_peeling.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_unswitch_pass.h" +#include "source/opt/merge_return_pass.h" +#include "source/opt/null_pass.h" +#include "source/opt/private_to_local_pass.h" +#include "source/opt/reduce_load_size.h" +#include "source/opt/redundancy_elimination.h" +#include "source/opt/remove_duplicates_pass.h" +#include "source/opt/replace_invalid_opc.h" +#include "source/opt/scalar_replacement_pass.h" +#include "source/opt/set_spec_constant_default_value_pass.h" +#include "source/opt/ssa_rewrite_pass.h" +#include "source/opt/strength_reduction_pass.h" +#include "source/opt/strip_debug_info_pass.h" +#include "source/opt/strip_reflect_info_pass.h" +#include "source/opt/unify_const_pass.h" +#include "source/opt/vector_dce.h" +#include "source/opt/workaround1209.h" + +#endif // SOURCE_OPT_PASSES_H_ diff --git a/3rdparty/spirv-tools/source/opt/private_to_local_pass.cpp b/3rdparty/spirv-tools/source/opt/private_to_local_pass.cpp index cc8ef53c1..02909a72b 100644 --- a/3rdparty/spirv-tools/source/opt/private_to_local_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/private_to_local_pass.cpp @@ -12,20 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "private_to_local_pass.h" +#include "source/opt/private_to_local_pass.h" -#include "ir_context.h" +#include +#include +#include -namespace { -const uint32_t kVariableStorageClassInIdx = 0; -const uint32_t kSpvTypePointerTypeIdInIdx = 1; -} // namespace +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { +namespace { -Pass::Status PrivateToLocalPass::Process(ir::IRContext* c) { - InitializeProcessing(c); +const uint32_t kVariableStorageClassInIdx = 0; +const uint32_t kSpvTypePointerTypeIdInIdx = 1; + +} // namespace + +Pass::Status PrivateToLocalPass::Process() { bool modified = false; // Private variables require the shader capability. If this is not a shader, @@ -33,7 +37,7 @@ Pass::Status PrivateToLocalPass::Process(ir::IRContext* c) { if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) return Status::SuccessWithoutChange; - std::vector> variables_to_move; + std::vector> variables_to_move; for (auto& inst : context()->types_values()) { if (inst.opcode() != SpvOpVariable) { continue; @@ -44,7 +48,7 @@ Pass::Status PrivateToLocalPass::Process(ir::IRContext* c) { continue; } - ir::Function* target_function = FindLocalFunction(inst); + Function* target_function = FindLocalFunction(inst); if (target_function != nullptr) { variables_to_move.push_back({&inst, target_function}); } @@ -58,14 +62,13 @@ Pass::Status PrivateToLocalPass::Process(ir::IRContext* c) { return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } -ir::Function* PrivateToLocalPass::FindLocalFunction( - const ir::Instruction& inst) const { +Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const { bool found_first_use = false; - ir::Function* target_function = nullptr; + Function* target_function = nullptr; context()->get_def_use_mgr()->ForEachUser( inst.result_id(), - [&target_function, &found_first_use, this](ir::Instruction* use) { - ir::BasicBlock* current_block = context()->get_instr_block(use); + [&target_function, &found_first_use, this](Instruction* use) { + BasicBlock* current_block = context()->get_instr_block(use); if (current_block == nullptr) { return; } @@ -75,7 +78,7 @@ ir::Function* PrivateToLocalPass::FindLocalFunction( target_function = nullptr; return; } - ir::Function* current_function = current_block->GetParent(); + Function* current_function = current_block->GetParent(); if (!found_first_use) { found_first_use = true; target_function = current_function; @@ -86,12 +89,12 @@ ir::Function* PrivateToLocalPass::FindLocalFunction( return target_function; } // namespace opt -void PrivateToLocalPass::MoveVariable(ir::Instruction* variable, - ir::Function* function) { +void PrivateToLocalPass::MoveVariable(Instruction* variable, + Function* function) { // The variable needs to be removed from the global section, and placed in the // header of the function. First step remove from the global list. variable->RemoveFromList(); - std::unique_ptr var(variable); // Take ownership. + std::unique_ptr var(variable); // Take ownership. context()->ForgetUses(variable); // Update the storage class of the variable. @@ -103,6 +106,7 @@ void PrivateToLocalPass::MoveVariable(ir::Instruction* variable, // Place the variable at the start of the first basic block. context()->AnalyzeUses(variable); + context()->set_instr_block(variable, &*function->begin()); function->begin()->begin()->InsertBefore(move(var)); // Update uses where the type may have changed. @@ -111,15 +115,16 @@ void PrivateToLocalPass::MoveVariable(ir::Instruction* variable, uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) { auto type_mgr = context()->get_type_mgr(); - ir::Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id); + Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id); uint32_t pointee_type_id = old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx); uint32_t new_type_id = type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction); + context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id)); return new_type_id; } -bool PrivateToLocalPass::IsValidUse(const ir::Instruction* inst) const { +bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const { // The cases in this switch have to match the cases in |UpdateUse|. // If we don't know how to update it, it is not valid. switch (inst->opcode()) { @@ -129,7 +134,7 @@ bool PrivateToLocalPass::IsValidUse(const ir::Instruction* inst) const { return true; case SpvOpAccessChain: return context()->get_def_use_mgr()->WhileEachUser( - inst, [this](const ir::Instruction* user) { + inst, [this](const Instruction* user) { if (!IsValidUse(user)) return false; return true; }); @@ -140,7 +145,7 @@ bool PrivateToLocalPass::IsValidUse(const ir::Instruction* inst) const { } } -void PrivateToLocalPass::UpdateUse(ir::Instruction* inst) { +void PrivateToLocalPass::UpdateUse(Instruction* inst) { // The cases in this switch have to match the cases in |IsValidUse|. If we // don't think it is valid, the optimization will not view the variable as a // candidate, and therefore the use will not be updated. @@ -168,11 +173,11 @@ void PrivateToLocalPass::UpdateUse(ir::Instruction* inst) { } } void PrivateToLocalPass::UpdateUses(uint32_t id) { - std::vector uses; - this->context()->get_def_use_mgr()->ForEachUser( - id, [&uses](ir::Instruction* use) { uses.push_back(use); }); + std::vector uses; + context()->get_def_use_mgr()->ForEachUser( + id, [&uses](Instruction* use) { uses.push_back(use); }); - for (ir::Instruction* use : uses) { + for (Instruction* use : uses) { UpdateUse(use); } } diff --git a/3rdparty/spirv-tools/source/opt/private_to_local_pass.h b/3rdparty/spirv-tools/source/opt/private_to_local_pass.h index 89cd994f1..f706e6e91 100644 --- a/3rdparty/spirv-tools/source/opt/private_to_local_pass.h +++ b/3rdparty/spirv-tools/source/opt/private_to_local_pass.h @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_PRIVATE_TO_LOCAL_PASS_H_ -#define LIBSPIRV_OPT_PRIVATE_TO_LOCAL_PASS_H_ +#ifndef SOURCE_OPT_PRIVATE_TO_LOCAL_PASS_H_ +#define SOURCE_OPT_PRIVATE_TO_LOCAL_PASS_H_ -#include "ir_context.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -28,30 +28,31 @@ namespace opt { class PrivateToLocalPass : public Pass { public: const char* name() const override { return "private-to-local"; } - Status Process(ir::IRContext*) override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisDecorations | - ir::IRContext::kAnalysisCombinators | ir::IRContext::kAnalysisCFG | - ir::IRContext::kAnalysisDominatorAnalysis | - ir::IRContext::kAnalysisNameMap; + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; } private: // Moves |variable| from the private storage class to the function storage // class of |function|. - void MoveVariable(ir::Instruction* variable, ir::Function* function); + void MoveVariable(Instruction* variable, Function* function); // |inst| is an instruction declaring a varible. If that variable is // referenced in a single function and all of uses are valid as defined by // |IsValidUse|, then that function is returned. Otherwise, the return // value is |nullptr|. - ir::Function* FindLocalFunction(const ir::Instruction& inst) const; + Function* FindLocalFunction(const Instruction& inst) const; // Returns true is |inst| is a valid use of a pointer. In this case, a // valid use is one where the transformation is able to rewrite the type to // match a change in storage class of the original variable. - bool IsValidUse(const ir::Instruction* inst) const; + bool IsValidUse(const Instruction* inst) const; // Given the result id of a pointer type, |old_type_id|, this function // returns the id of a the same pointer type except the storage class has @@ -61,11 +62,11 @@ class PrivateToLocalPass : public Pass { // Updates |inst|, and any instruction dependent on |inst|, to reflect the // change of the base pointer now pointing to the function storage class. - void UpdateUse(ir::Instruction* inst); + void UpdateUse(Instruction* inst); void UpdateUses(uint32_t id); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_PRIVATE_TO_LOCAL_PASS_H_ +#endif // SOURCE_OPT_PRIVATE_TO_LOCAL_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/propagator.cpp b/3rdparty/spirv-tools/source/opt/propagator.cpp index d5d76127a..6a1f1aafb 100644 --- a/3rdparty/spirv-tools/source/opt/propagator.cpp +++ b/3rdparty/spirv-tools/source/opt/propagator.cpp @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "propagator.h" +#include "source/opt/propagator.h" namespace spvtools { namespace opt { void SSAPropagator::AddControlEdge(const Edge& edge) { - ir::BasicBlock* dest_bb = edge.dest; + BasicBlock* dest_bb = edge.dest; // Refuse to add the exit block to the work list. if (dest_bb == ctx_->cfg()->pseudo_exit_block()) { @@ -36,14 +36,14 @@ void SSAPropagator::AddControlEdge(const Edge& edge) { blocks_.push(dest_bb); } -void SSAPropagator::AddSSAEdges(ir::Instruction* instr) { +void SSAPropagator::AddSSAEdges(Instruction* instr) { // Ignore instructions that produce no result. if (instr->result_id() == 0) { return; } get_def_use_mgr()->ForEachUser( - instr->result_id(), [this](ir::Instruction* use_instr) { + instr->result_id(), [this](Instruction* use_instr) { // If the basic block for |use_instr| has not been simulated yet, do // nothing. The instruction |use_instr| will be simulated next time the // block is scheduled. @@ -57,17 +57,17 @@ void SSAPropagator::AddSSAEdges(ir::Instruction* instr) { }); } -bool SSAPropagator::IsPhiArgExecutable(ir::Instruction* phi, uint32_t i) const { - ir::BasicBlock* phi_bb = ctx_->get_instr_block(phi); +bool SSAPropagator::IsPhiArgExecutable(Instruction* phi, uint32_t i) const { + BasicBlock* phi_bb = ctx_->get_instr_block(phi); uint32_t in_label_id = phi->GetSingleWordOperand(i + 1); - ir::Instruction* in_label_instr = get_def_use_mgr()->GetDef(in_label_id); - ir::BasicBlock* in_bb = ctx_->get_instr_block(in_label_instr); + Instruction* in_label_instr = get_def_use_mgr()->GetDef(in_label_id); + BasicBlock* in_bb = ctx_->get_instr_block(in_label_instr); return IsEdgeExecutable(Edge(in_bb, phi_bb)); } -bool SSAPropagator::SetStatus(ir::Instruction* inst, PropStatus status) { +bool SSAPropagator::SetStatus(Instruction* inst, PropStatus status) { bool has_old_status = false; PropStatus old_status = kVarying; if (HasStatus(inst)) { @@ -84,7 +84,7 @@ bool SSAPropagator::SetStatus(ir::Instruction* inst, PropStatus status) { return status_changed; } -bool SSAPropagator::Simulate(ir::Instruction* instr) { +bool SSAPropagator::Simulate(Instruction* instr) { bool changed = false; // Don't bother visiting instructions that should not be simulated again. @@ -92,7 +92,7 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) { return changed; } - ir::BasicBlock* dest_bb = nullptr; + BasicBlock* dest_bb = nullptr; PropStatus status = visit_fn_(instr, &dest_bb); bool status_changed = SetStatus(instr, status); @@ -107,7 +107,7 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) { // If |instr| is a block terminator, add all the control edges out of its // block. if (instr->IsBlockTerminator()) { - ir::BasicBlock* block = ctx_->get_instr_block(instr); + BasicBlock* block = ctx_->get_instr_block(instr); for (const auto& e : bb_succs_.at(block)) { AddControlEdge(e); } @@ -145,7 +145,7 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) { "malformed Phi arguments"); uint32_t arg_id = instr->GetSingleWordOperand(i); - ir::Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); + Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); if (!IsPhiArgExecutable(instr, i) || ShouldSimulateAgain(arg_def_instr)) { has_operands_to_simulate = true; break; @@ -157,7 +157,7 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) { // also be simulated again. has_operands_to_simulate = !instr->WhileEachInId([this](const uint32_t* use) { - ir::Instruction* def_instr = get_def_use_mgr()->GetDef(*use); + Instruction* def_instr = get_def_use_mgr()->GetDef(*use); if (ShouldSimulateAgain(def_instr)) { return false; } @@ -172,7 +172,7 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) { return changed; } -bool SSAPropagator::Simulate(ir::BasicBlock* block) { +bool SSAPropagator::Simulate(BasicBlock* block) { if (block == ctx_->cfg()->pseudo_exit_block()) { return false; } @@ -183,12 +183,12 @@ bool SSAPropagator::Simulate(ir::BasicBlock* block) { // operand can be simulated. bool changed = false; block->ForEachPhiInst( - [&changed, this](ir::Instruction* instr) { changed |= Simulate(instr); }); + [&changed, this](Instruction* instr) { changed |= Simulate(instr); }); // If this is the first time this block is being simulated, simulate every // statement in it. if (!BlockHasBeenSimulated(block)) { - block->ForEachInst([this, &changed](ir::Instruction* instr) { + block->ForEachInst([this, &changed](Instruction* instr) { if (instr->opcode() != SpvOpPhi) { changed |= Simulate(instr); } @@ -206,9 +206,9 @@ bool SSAPropagator::Simulate(ir::BasicBlock* block) { return changed; } -void SSAPropagator::Initialize(ir::Function* fn) { +void SSAPropagator::Initialize(Function* fn) { // Compute predecessor and successor blocks for every block in |fn|'s CFG. - // TODO(dnovillo): Move this to ir::CFG and always build them. Alternately, + // TODO(dnovillo): Move this to CFG and always build them. Alternately, // move it to IRContext and build CFG preds/succs on-demand. bb_succs_[ctx_->cfg()->pseudo_entry_block()].push_back( Edge(ctx_->cfg()->pseudo_entry_block(), fn->entry().get())); @@ -216,7 +216,7 @@ void SSAPropagator::Initialize(ir::Function* fn) { for (auto& block : *fn) { const auto& const_block = block; const_block.ForEachSuccessorLabel([this, &block](const uint32_t label_id) { - ir::BasicBlock* succ_bb = + BasicBlock* succ_bb = ctx_->get_instr_block(get_def_use_mgr()->GetDef(label_id)); bb_succs_[&block].push_back(Edge(&block, succ_bb)); bb_preds_[succ_bb].push_back(Edge(succ_bb, &block)); @@ -236,7 +236,7 @@ void SSAPropagator::Initialize(ir::Function* fn) { } } -bool SSAPropagator::Run(ir::Function* fn) { +bool SSAPropagator::Run(Function* fn) { Initialize(fn); bool changed = false; @@ -252,7 +252,7 @@ bool SSAPropagator::Run(ir::Function* fn) { // Simulate edges from the SSA queue. if (!ssa_edge_uses_.empty()) { - ir::Instruction* instr = ssa_edge_uses_.front(); + Instruction* instr = ssa_edge_uses_.front(); changed |= Simulate(instr); ssa_edge_uses_.pop(); } @@ -261,7 +261,7 @@ bool SSAPropagator::Run(ir::Function* fn) { #ifndef NDEBUG // Verify all visited values have settled. No value that has been simulated // should end on not interesting. - fn->ForEachInst([this](ir::Instruction* inst) { + fn->ForEachInst([this](Instruction* inst) { assert( (!HasStatus(inst) || Status(inst) != SSAPropagator::kNotInteresting) && "Unsettled value"); diff --git a/3rdparty/spirv-tools/source/opt/propagator.h b/3rdparty/spirv-tools/source/opt/propagator.h index f81690c77..ac7c0e7ea 100644 --- a/3rdparty/spirv-tools/source/opt/propagator.h +++ b/3rdparty/spirv-tools/source/opt/propagator.h @@ -12,30 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_PROPAGATOR_H_ -#define LIBSPIRV_OPT_PROPAGATOR_H_ +#ifndef SOURCE_OPT_PROPAGATOR_H_ +#define SOURCE_OPT_PROPAGATOR_H_ #include #include #include #include #include +#include #include -#include "ir_context.h" -#include "module.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" namespace spvtools { namespace opt { // Represents a CFG control edge. struct Edge { - Edge(ir::BasicBlock* b1, ir::BasicBlock* b2) : source(b1), dest(b2) { + Edge(BasicBlock* b1, BasicBlock* b2) : source(b1), dest(b2) { assert(source && "CFG edges cannot have a null source block."); assert(dest && "CFG edges cannot have a null destination block."); } - ir::BasicBlock* source; - ir::BasicBlock* dest; + BasicBlock* source; + BasicBlock* dest; bool operator<(const Edge& o) const { return std::make_pair(source->id(), dest->id()) < std::make_pair(o.source->id(), o.dest->id()); @@ -148,22 +149,22 @@ struct Edge { // following code builds a table |values| where every id that was assigned a // constant value is mapped to the constant value it was assigned. // -// auto ctx = spvtools::BuildModule(...); +// auto ctx = BuildModule(...); // std::map values; -// const auto visit_fn = [&ctx, &values](ir::Instruction* instr, -// ir::BasicBlock** dest_bb) { +// const auto visit_fn = [&ctx, &values](Instruction* instr, +// BasicBlock** dest_bb) { // if (instr->opcode() == SpvOpStore) { // uint32_t rhs_id = instr->GetSingleWordOperand(1); -// ir::Instruction* rhs_def = ctx->get_def_use_mgr()->GetDef(rhs_id); +// Instruction* rhs_def = ctx->get_def_use_mgr()->GetDef(rhs_id); // if (rhs_def->opcode() == SpvOpConstant) { // uint32_t val = rhs_def->GetSingleWordOperand(2); // values[rhs_id] = val; -// return opt::SSAPropagator::kInteresting; +// return SSAPropagator::kInteresting; // } // } -// return opt::SSAPropagator::kVarying; +// return SSAPropagator::kVarying; // }; -// opt::SSAPropagator propagator(ctx.get(), &cfg, visit_fn); +// SSAPropagator propagator(ctx.get(), &cfg, visit_fn); // propagator.Run(&fn); // // Given the code: @@ -183,66 +184,63 @@ class SSAPropagator { // a description. enum PropStatus { kNotInteresting, kInteresting, kVarying }; - using VisitFunction = - std::function; + using VisitFunction = std::function; - SSAPropagator(ir::IRContext* context, const VisitFunction& visit_fn) + SSAPropagator(IRContext* context, const VisitFunction& visit_fn) : ctx_(context), visit_fn_(visit_fn) {} // Runs the propagator on function |fn|. Returns true if changes were made to // the function. Otherwise, it returns false. - bool Run(ir::Function* fn); + bool Run(Function* fn); // Returns true if the |i|th argument for |phi| comes through a CFG edge that // has been marked executable. |i| should be an index value accepted by // Instruction::GetSingleWordOperand. - bool IsPhiArgExecutable(ir::Instruction* phi, uint32_t i) const; + bool IsPhiArgExecutable(Instruction* phi, uint32_t i) const; // Returns true if |inst| has a recorded status. This will be true once |inst| // has been simulated once. - bool HasStatus(ir::Instruction* inst) const { return statuses_.count(inst); } + bool HasStatus(Instruction* inst) const { return statuses_.count(inst); } // Returns the current propagation status of |inst|. Assumes // |HasStatus(inst)| returns true. - PropStatus Status(ir::Instruction* inst) const { + PropStatus Status(Instruction* inst) const { return statuses_.find(inst)->second; } // Records the propagation status |status| for |inst|. Returns true if the // status for |inst| has changed or set was set for the first time. - bool SetStatus(ir::Instruction* inst, PropStatus status); + bool SetStatus(Instruction* inst, PropStatus status); private: // Initialize processing. - void Initialize(ir::Function* fn); + void Initialize(Function* fn); // Simulate the execution |block| by calling |visit_fn_| on every instruction // in it. - bool Simulate(ir::BasicBlock* block); + bool Simulate(BasicBlock* block); // Simulate the execution of |instr| by replacing all the known values in // every operand and determining whether the result is interesting for // propagation. This invokes the callback function |visit_fn_| to determine // the value computed by |instr|. - bool Simulate(ir::Instruction* instr); + bool Simulate(Instruction* instr); // Returns true if |instr| should be simulated again. - bool ShouldSimulateAgain(ir::Instruction* instr) const { + bool ShouldSimulateAgain(Instruction* instr) const { return do_not_simulate_.find(instr) == do_not_simulate_.end(); } // Add |instr| to the set of instructions not to simulate again. - void DontSimulateAgain(ir::Instruction* instr) { - do_not_simulate_.insert(instr); - } + void DontSimulateAgain(Instruction* instr) { do_not_simulate_.insert(instr); } // Returns true if |block| has been simulated already. - bool BlockHasBeenSimulated(ir::BasicBlock* block) const { + bool BlockHasBeenSimulated(BasicBlock* block) const { return simulated_blocks_.find(block) != simulated_blocks_.end(); } // Marks block |block| as simulated. - void MarkBlockSimulated(ir::BasicBlock* block) { + void MarkBlockSimulated(BasicBlock* block) { simulated_blocks_.insert(block); } @@ -268,10 +266,10 @@ class SSAPropagator { // Adds all the instructions that use the result of |instr| to the SSA edges // work list. If |instr| produces no result id, this does nothing. - void AddSSAEdges(ir::Instruction* instr); + void AddSSAEdges(Instruction* instr); // IR context to use. - ir::IRContext* ctx_; + IRContext* ctx_; // Function that visits instructions during simulation. The output of this // function is used to determine if the simulated instruction produced a value @@ -281,33 +279,33 @@ class SSAPropagator { // SSA def-use edges to traverse. Each entry is a destination statement for an // SSA def-use edge as returned by |def_use_manager_|. - std::queue ssa_edge_uses_; + std::queue ssa_edge_uses_; // Blocks to simulate. - std::queue blocks_; + std::queue blocks_; // Blocks simulated during propagation. - std::unordered_set simulated_blocks_; + std::unordered_set simulated_blocks_; // Set of instructions that should not be simulated again because they have // been found to be in the kVarying state. - std::unordered_set do_not_simulate_; + std::unordered_set do_not_simulate_; // Map between a basic block and its predecessor edges. - // TODO(dnovillo): Move this to ir::CFG and always build them. Alternately, + // TODO(dnovillo): Move this to CFG and always build them. Alternately, // move it to IRContext and build CFG preds/succs on-demand. - std::unordered_map> bb_preds_; + std::unordered_map> bb_preds_; // Map between a basic block and its successor edges. - // TODO(dnovillo): Move this to ir::CFG and always build them. Alternately, + // TODO(dnovillo): Move this to CFG and always build them. Alternately, // move it to IRContext and build CFG preds/succs on-demand. - std::unordered_map> bb_succs_; + std::unordered_map> bb_succs_; // Set of executable CFG edges. std::set executable_edges_; // Tracks instruction propagation status. - std::unordered_map statuses_; + std::unordered_map statuses_; }; std::ostream& operator<<(std::ostream& str, @@ -316,4 +314,4 @@ std::ostream& operator<<(std::ostream& str, } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_PROPAGATOR_H_ +#endif // SOURCE_OPT_PROPAGATOR_H_ diff --git a/3rdparty/spirv-tools/source/opt/reduce_load_size.cpp b/3rdparty/spirv-tools/source/opt/reduce_load_size.cpp new file mode 100644 index 000000000..b692c6b54 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/reduce_load_size.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/reduce_load_size.h" + +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/util/bit_vector.h" + +namespace { + +const uint32_t kExtractCompositeIdInIdx = 0; +const uint32_t kVariableStorageClassInIdx = 0; +const uint32_t kLoadPointerInIdx = 0; +const double kThreshold = 0.9; + +} // namespace + +namespace spvtools { +namespace opt { + +Pass::Status ReduceLoadSize::Process() { + bool modified = false; + + for (auto& func : *get_module()) { + func.ForEachInst([&modified, this](Instruction* inst) { + if (inst->opcode() == SpvOpCompositeExtract) { + if (ShouldReplaceExtract(inst)) { + modified |= ReplaceExtract(inst); + } + } + }); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool ReduceLoadSize::ReplaceExtract(Instruction* inst) { + assert(inst->opcode() == SpvOpCompositeExtract && + "Wrong opcode. Should be OpCompositeExtract."); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + uint32_t composite_id = + inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + Instruction* composite_inst = def_use_mgr->GetDef(composite_id); + + if (composite_inst->opcode() != SpvOpLoad) { + return false; + } + + analysis::Type* composite_type = type_mgr->GetType(composite_inst->type_id()); + if (composite_type->kind() == analysis::Type::kVector || + composite_type->kind() == analysis::Type::kMatrix) { + return false; + } + + Instruction* var = composite_inst->GetBaseAddress(); + if (var == nullptr || var->opcode() != SpvOpVariable) { + return false; + } + + SpvStorageClass storage_class = static_cast( + var->GetSingleWordInOperand(kVariableStorageClassInIdx)); + switch (storage_class) { + case SpvStorageClassUniform: + case SpvStorageClassUniformConstant: + case SpvStorageClassInput: + break; + default: + return false; + } + + // Create a new access chain and load just after the old load. + // We cannot create the new access chain load in the position of the extract + // because the storage may have been written to in between. + InstructionBuilder ir_builder( + inst->context(), composite_inst, + IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); + + uint32_t pointer_to_result_type_id = + type_mgr->FindPointerToType(inst->type_id(), storage_class); + assert(pointer_to_result_type_id != 0 && + "We did not find the pointer type that we need."); + + analysis::Integer int_type(32, false); + const analysis::Type* uint32_type = type_mgr->GetRegisteredType(&int_type); + std::vector ids; + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + uint32_t index = inst->GetSingleWordInOperand(i); + const analysis::Constant* index_const = + const_mgr->GetConstant(uint32_type, {index}); + ids.push_back(const_mgr->GetDefiningInstruction(index_const)->result_id()); + } + + Instruction* new_access_chain = ir_builder.AddAccessChain( + pointer_to_result_type_id, + composite_inst->GetSingleWordInOperand(kLoadPointerInIdx), ids); + Instruction* new_laod = + ir_builder.AddLoad(inst->type_id(), new_access_chain->result_id()); + + context()->ReplaceAllUsesWith(inst->result_id(), new_laod->result_id()); + context()->KillInst(inst); + return true; +} + +bool ReduceLoadSize::ShouldReplaceExtract(Instruction* inst) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + Instruction* op_inst = def_use_mgr->GetDef( + inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)); + + if (op_inst->opcode() != SpvOpLoad) { + return false; + } + + auto cached_result = should_replace_cache_.find(op_inst->result_id()); + if (cached_result != should_replace_cache_.end()) { + return cached_result->second; + } + + bool all_elements_used = false; + std::set elements_used; + + all_elements_used = + !def_use_mgr->WhileEachUser(op_inst, [&elements_used](Instruction* use) { + if (use->opcode() != SpvOpCompositeExtract) { + return false; + } + elements_used.insert(use->GetSingleWordInOperand(1)); + return true; + }); + + bool should_replace = false; + if (all_elements_used) { + should_replace = false; + } else { + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Type* load_type = type_mgr->GetType(op_inst->type_id()); + uint32_t total_size = 1; + switch (load_type->kind()) { + case analysis::Type::kArray: { + const analysis::Constant* size_const = + const_mgr->FindDeclaredConstant(load_type->AsArray()->LengthId()); + assert(size_const->AsIntConstant()); + total_size = size_const->GetU32(); + } break; + case analysis::Type::kStruct: + total_size = static_cast( + load_type->AsStruct()->element_types().size()); + break; + default: + break; + } + double percent_used = static_cast(elements_used.size()) / + static_cast(total_size); + should_replace = (percent_used < kThreshold); + } + + should_replace_cache_[op_inst->result_id()] = should_replace; + return should_replace; +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/reduce_load_size.h b/3rdparty/spirv-tools/source/opt/reduce_load_size.h new file mode 100644 index 000000000..724a430bb --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/reduce_load_size.h @@ -0,0 +1,64 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_REDUCE_LOAD_SIZE_H_ +#define SOURCE_OPT_REDUCE_LOAD_SIZE_H_ + +#include + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class ReduceLoadSize : public Pass { + public: + const char* name() const override { return "reduce-load-size"; } + Status Process() override; + + // Return the mask of preserved Analyses. + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG | + IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap; + } + + private: + // Replaces |inst|, which must be an OpCompositeExtract instruction, with + // an OpAccessChain and a load if possible. This happens only if it is a load + // feeding |inst|. Returns true if the substitution happened. The position + // of the new instructions will be in the same place as the load feeding the + // extract. + bool ReplaceExtract(Instruction* inst); + + // Returns true if the OpCompositeExtract instruction |inst| should be replace + // or not. This is determined by looking at the load that feeds |inst| if + // it is a load. |should_replace_cache_| is used to cache the results based + // on the load feeding |inst|. + bool ShouldReplaceExtract(Instruction* inst); + + // Maps the result id of an OpLoad instruction to the result of whether or + // not the OpCompositeExtract that use the id should be replaced. + std::unordered_map should_replace_cache_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_REDUCE_LOAD_SIZE_H_ diff --git a/3rdparty/spirv-tools/source/opt/redundancy_elimination.cpp b/3rdparty/spirv-tools/source/opt/redundancy_elimination.cpp index 7e5c01d5f..362e54dc6 100644 --- a/3rdparty/spirv-tools/source/opt/redundancy_elimination.cpp +++ b/3rdparty/spirv-tools/source/opt/redundancy_elimination.cpp @@ -12,24 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "redundancy_elimination.h" +#include "source/opt/redundancy_elimination.h" -#include "value_number_table.h" +#include "source/opt/value_number_table.h" namespace spvtools { namespace opt { -Pass::Status RedundancyEliminationPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status RedundancyEliminationPass::Process() { bool modified = false; ValueNumberTable vnTable(context()); for (auto& func : *get_module()) { // Build the dominator tree for this function. It is how the code is // traversed. - opt::DominatorTree& dom_tree = - context()->GetDominatorAnalysis(&func, *context()->cfg())->GetDomTree(); + DominatorTree& dom_tree = + context()->GetDominatorAnalysis(&func)->GetDomTree(); // Keeps track of all ids that contain a given value number. We keep // track of multiple values because they could have the same value, but diff --git a/3rdparty/spirv-tools/source/opt/redundancy_elimination.h b/3rdparty/spirv-tools/source/opt/redundancy_elimination.h index 634ecc330..91809b5d5 100644 --- a/3rdparty/spirv-tools/source/opt/redundancy_elimination.h +++ b/3rdparty/spirv-tools/source/opt/redundancy_elimination.h @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_REDUNDANCY_ELIMINATION_H_ -#define LIBSPIRV_OPT_REDUNDANCY_ELIMINATION_H_ +#ifndef SOURCE_OPT_REDUNDANCY_ELIMINATION_H_ +#define SOURCE_OPT_REDUNDANCY_ELIMINATION_H_ -#include "ir_context.h" -#include "local_redundancy_elimination.h" -#include "pass.h" -#include "value_number_table.h" +#include + +#include "source/opt/ir_context.h" +#include "source/opt/local_redundancy_elimination.h" +#include "source/opt/pass.h" +#include "source/opt/value_number_table.h" namespace spvtools { namespace opt { @@ -30,7 +32,7 @@ namespace opt { class RedundancyEliminationPass : public LocalRedundancyEliminationPass { public: const char* name() const override { return "redundancy-elimination"; } - Status Process(ir::IRContext*) override; + Status Process() override; protected: // Removes for all total redundancies in the function starting at |bb|. @@ -51,4 +53,4 @@ class RedundancyEliminationPass : public LocalRedundancyEliminationPass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_REDUNDANCY_ELIMINATION_H_ +#endif // SOURCE_OPT_REDUNDANCY_ELIMINATION_H_ diff --git a/3rdparty/spirv-tools/source/opt/reflect.h b/3rdparty/spirv-tools/source/opt/reflect.h index ef2d84947..fb2de7b15 100644 --- a/3rdparty/spirv-tools/source/opt/reflect.h +++ b/3rdparty/spirv-tools/source/opt/reflect.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_REFLECT_H_ -#define LIBSPIRV_OPT_REFLECT_H_ +#ifndef SOURCE_OPT_REFLECT_H_ +#define SOURCE_OPT_REFLECT_H_ -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" namespace spvtools { -namespace ir { +namespace opt { // Note that as SPIR-V evolves over time, new opcodes may appear. So the // following functions tend to be outdated and should be updated when SPIR-V @@ -59,7 +59,7 @@ inline bool IsTerminatorInst(SpvOp opcode) { return opcode >= SpvOpBranch && opcode <= SpvOpUnreachable; } -} // namespace ir +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_REFLECT_H_ +#endif // SOURCE_OPT_REFLECT_H_ diff --git a/3rdparty/spirv-tools/source/opt/register_pressure.cpp b/3rdparty/spirv-tools/source/opt/register_pressure.cpp new file mode 100644 index 000000000..34dac1d7b --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/register_pressure.cpp @@ -0,0 +1,576 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/register_pressure.h" + +#include +#include + +#include "source/opt/cfg.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +namespace { +// Predicate for the FilterIterator to only consider instructions that are not +// phi instructions defined in the basic block |bb|. +class ExcludePhiDefinedInBlock { + public: + ExcludePhiDefinedInBlock(IRContext* context, const BasicBlock* bb) + : context_(context), bb_(bb) {} + + bool operator()(Instruction* insn) const { + return !(insn->opcode() == SpvOpPhi && + context_->get_instr_block(insn) == bb_); + } + + private: + IRContext* context_; + const BasicBlock* bb_; +}; + +// Returns true if |insn| generates a SSA register that is likely to require a +// physical register. +bool CreatesRegisterUsage(Instruction* insn) { + if (!insn->HasResultId()) return false; + if (insn->opcode() == SpvOpUndef) return false; + if (IsConstantInst(insn->opcode())) return false; + if (insn->opcode() == SpvOpLabel) return false; + return true; +} + +// Compute the register liveness for each basic block of a function. This also +// fill-up some information about the pick register usage and a break down of +// register usage. This implements: "A non-iterative data-flow algorithm for +// computing liveness sets in strict ssa programs" from Boissinot et al. +class ComputeRegisterLiveness { + public: + ComputeRegisterLiveness(RegisterLiveness* reg_pressure, Function* f) + : reg_pressure_(reg_pressure), + context_(reg_pressure->GetContext()), + function_(f), + cfg_(*reg_pressure->GetContext()->cfg()), + def_use_manager_(*reg_pressure->GetContext()->get_def_use_mgr()), + dom_tree_( + reg_pressure->GetContext()->GetDominatorAnalysis(f)->GetDomTree()), + loop_desc_(*reg_pressure->GetContext()->GetLoopDescriptor(f)) {} + + // Computes the register liveness for |function_| and then estimate the + // register usage. The liveness algorithm works in 2 steps: + // - First, compute the liveness for each basic blocks, but will ignore any + // back-edge; + // - Second, walk loop forest to propagate registers crossing back-edges + // (add iterative values into the liveness set). + void Compute() { + cfg_.ForEachBlockInPostOrder(&*function_->begin(), [this](BasicBlock* bb) { + ComputePartialLiveness(bb); + }); + DoLoopLivenessUnification(); + EvaluateRegisterRequirements(); + } + + private: + // Registers all SSA register used by successors of |bb| in their phi + // instructions. + void ComputePhiUses(const BasicBlock& bb, + RegisterLiveness::RegionRegisterLiveness::LiveSet* live) { + uint32_t bb_id = bb.id(); + bb.ForEachSuccessorLabel([live, bb_id, this](uint32_t sid) { + BasicBlock* succ_bb = cfg_.block(sid); + succ_bb->ForEachPhiInst([live, bb_id, this](const Instruction* phi) { + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + if (phi->GetSingleWordInOperand(i + 1) == bb_id) { + Instruction* insn_op = + def_use_manager_.GetDef(phi->GetSingleWordInOperand(i)); + if (CreatesRegisterUsage(insn_op)) { + live->insert(insn_op); + break; + } + } + } + }); + }); + } + + // Computes register liveness for each basic blocks but ignores all + // back-edges. + void ComputePartialLiveness(BasicBlock* bb) { + assert(reg_pressure_->Get(bb) == nullptr && + "Basic block already processed"); + + RegisterLiveness::RegionRegisterLiveness* live_inout = + reg_pressure_->GetOrInsert(bb->id()); + ComputePhiUses(*bb, &live_inout->live_out_); + + const BasicBlock* cbb = bb; + cbb->ForEachSuccessorLabel([&live_inout, bb, this](uint32_t sid) { + // Skip back edges. + if (dom_tree_.Dominates(sid, bb->id())) { + return; + } + + BasicBlock* succ_bb = cfg_.block(sid); + RegisterLiveness::RegionRegisterLiveness* succ_live_inout = + reg_pressure_->Get(succ_bb); + assert(succ_live_inout && + "Successor liveness analysis was not performed"); + + ExcludePhiDefinedInBlock predicate(context_, succ_bb); + auto filter = + MakeFilterIteratorRange(succ_live_inout->live_in_.begin(), + succ_live_inout->live_in_.end(), predicate); + live_inout->live_out_.insert(filter.begin(), filter.end()); + }); + + live_inout->live_in_ = live_inout->live_out_; + for (Instruction& insn : make_range(bb->rbegin(), bb->rend())) { + if (insn.opcode() == SpvOpPhi) { + live_inout->live_in_.insert(&insn); + break; + } + live_inout->live_in_.erase(&insn); + insn.ForEachInId([live_inout, this](uint32_t* id) { + Instruction* insn_op = def_use_manager_.GetDef(*id); + if (CreatesRegisterUsage(insn_op)) { + live_inout->live_in_.insert(insn_op); + } + }); + } + } + + // Propagates the register liveness information of each loop iterators. + void DoLoopLivenessUnification() { + for (const Loop* loop : *loop_desc_.GetDummyRootLoop()) { + DoLoopLivenessUnification(*loop); + } + } + + // Propagates the register liveness information of loop iterators trough-out + // the loop body. + void DoLoopLivenessUnification(const Loop& loop) { + auto blocks_in_loop = MakeFilterIteratorRange( + loop.GetBlocks().begin(), loop.GetBlocks().end(), + [&loop, this](uint32_t bb_id) { + return bb_id != loop.GetHeaderBlock()->id() && + loop_desc_[bb_id] == &loop; + }); + + RegisterLiveness::RegionRegisterLiveness* header_live_inout = + reg_pressure_->Get(loop.GetHeaderBlock()); + assert(header_live_inout && + "Liveness analysis was not performed for the current block"); + + ExcludePhiDefinedInBlock predicate(context_, loop.GetHeaderBlock()); + auto live_loop = + MakeFilterIteratorRange(header_live_inout->live_in_.begin(), + header_live_inout->live_in_.end(), predicate); + + for (uint32_t bb_id : blocks_in_loop) { + BasicBlock* bb = cfg_.block(bb_id); + + RegisterLiveness::RegionRegisterLiveness* live_inout = + reg_pressure_->Get(bb); + live_inout->live_in_.insert(live_loop.begin(), live_loop.end()); + live_inout->live_out_.insert(live_loop.begin(), live_loop.end()); + } + + for (const Loop* inner_loop : loop) { + RegisterLiveness::RegionRegisterLiveness* live_inout = + reg_pressure_->Get(inner_loop->GetHeaderBlock()); + live_inout->live_in_.insert(live_loop.begin(), live_loop.end()); + live_inout->live_out_.insert(live_loop.begin(), live_loop.end()); + + DoLoopLivenessUnification(*inner_loop); + } + } + + // Get the number of required registers for this each basic block. + void EvaluateRegisterRequirements() { + for (BasicBlock& bb : *function_) { + RegisterLiveness::RegionRegisterLiveness* live_inout = + reg_pressure_->Get(bb.id()); + assert(live_inout != nullptr && "Basic block not processed"); + + size_t reg_count = live_inout->live_out_.size(); + for (Instruction* insn : live_inout->live_out_) { + live_inout->AddRegisterClass(insn); + } + live_inout->used_registers_ = reg_count; + + std::unordered_set die_in_block; + for (Instruction& insn : make_range(bb.rbegin(), bb.rend())) { + // If it is a phi instruction, the register pressure will not change + // anymore. + if (insn.opcode() == SpvOpPhi) { + break; + } + + insn.ForEachInId( + [live_inout, &die_in_block, ®_count, this](uint32_t* id) { + Instruction* op_insn = def_use_manager_.GetDef(*id); + if (!CreatesRegisterUsage(op_insn) || + live_inout->live_out_.count(op_insn)) { + // already taken into account. + return; + } + if (!die_in_block.count(*id)) { + live_inout->AddRegisterClass(def_use_manager_.GetDef(*id)); + reg_count++; + die_in_block.insert(*id); + } + }); + live_inout->used_registers_ = + std::max(live_inout->used_registers_, reg_count); + if (CreatesRegisterUsage(&insn)) { + reg_count--; + } + } + } + } + + RegisterLiveness* reg_pressure_; + IRContext* context_; + Function* function_; + CFG& cfg_; + analysis::DefUseManager& def_use_manager_; + DominatorTree& dom_tree_; + LoopDescriptor& loop_desc_; +}; +} // namespace + +// Get the number of required registers for each basic block. +void RegisterLiveness::RegionRegisterLiveness::AddRegisterClass( + Instruction* insn) { + assert(CreatesRegisterUsage(insn) && "Instruction does not use a register"); + analysis::Type* type = + insn->context()->get_type_mgr()->GetType(insn->type_id()); + + RegisterLiveness::RegisterClass reg_class{type, false}; + + insn->context()->get_decoration_mgr()->WhileEachDecoration( + insn->result_id(), SpvDecorationUniform, + [®_class](const Instruction&) { + reg_class.is_uniform_ = true; + return false; + }); + + AddRegisterClass(reg_class); +} + +void RegisterLiveness::Analyze(Function* f) { + block_pressure_.clear(); + ComputeRegisterLiveness(this, f).Compute(); +} + +void RegisterLiveness::ComputeLoopRegisterPressure( + const Loop& loop, RegionRegisterLiveness* loop_reg_pressure) const { + loop_reg_pressure->Clear(); + + const RegionRegisterLiveness* header_live_inout = Get(loop.GetHeaderBlock()); + loop_reg_pressure->live_in_ = header_live_inout->live_in_; + + std::unordered_set exit_blocks; + loop.GetExitBlocks(&exit_blocks); + + for (uint32_t bb_id : exit_blocks) { + const RegionRegisterLiveness* live_inout = Get(bb_id); + loop_reg_pressure->live_out_.insert(live_inout->live_in_.begin(), + live_inout->live_in_.end()); + } + + std::unordered_set seen_insn; + for (Instruction* insn : loop_reg_pressure->live_out_) { + loop_reg_pressure->AddRegisterClass(insn); + seen_insn.insert(insn->result_id()); + } + for (Instruction* insn : loop_reg_pressure->live_in_) { + if (!seen_insn.count(insn->result_id())) { + continue; + } + loop_reg_pressure->AddRegisterClass(insn); + seen_insn.insert(insn->result_id()); + } + + loop_reg_pressure->used_registers_ = 0; + + for (uint32_t bb_id : loop.GetBlocks()) { + BasicBlock* bb = context_->cfg()->block(bb_id); + + const RegionRegisterLiveness* live_inout = Get(bb_id); + assert(live_inout != nullptr && "Basic block not processed"); + loop_reg_pressure->used_registers_ = std::max( + loop_reg_pressure->used_registers_, live_inout->used_registers_); + + for (Instruction& insn : *bb) { + if (insn.opcode() == SpvOpPhi || !CreatesRegisterUsage(&insn) || + seen_insn.count(insn.result_id())) { + continue; + } + loop_reg_pressure->AddRegisterClass(&insn); + } + } +} + +void RegisterLiveness::SimulateFusion( + const Loop& l1, const Loop& l2, RegionRegisterLiveness* sim_result) const { + sim_result->Clear(); + + // Compute the live-in state: + // sim_result.live_in = l1.live_in U l2.live_in + // This assumes that |l1| does not generated register that is live-out for + // |l1|. + const RegionRegisterLiveness* l1_header_live_inout = Get(l1.GetHeaderBlock()); + sim_result->live_in_ = l1_header_live_inout->live_in_; + + const RegionRegisterLiveness* l2_header_live_inout = Get(l2.GetHeaderBlock()); + sim_result->live_in_.insert(l2_header_live_inout->live_in_.begin(), + l2_header_live_inout->live_in_.end()); + + // The live-out set of the fused loop is the l2 live-out set. + std::unordered_set exit_blocks; + l2.GetExitBlocks(&exit_blocks); + + for (uint32_t bb_id : exit_blocks) { + const RegionRegisterLiveness* live_inout = Get(bb_id); + sim_result->live_out_.insert(live_inout->live_in_.begin(), + live_inout->live_in_.end()); + } + + // Compute the register usage information. + std::unordered_set seen_insn; + for (Instruction* insn : sim_result->live_out_) { + sim_result->AddRegisterClass(insn); + seen_insn.insert(insn->result_id()); + } + for (Instruction* insn : sim_result->live_in_) { + if (!seen_insn.count(insn->result_id())) { + continue; + } + sim_result->AddRegisterClass(insn); + seen_insn.insert(insn->result_id()); + } + + sim_result->used_registers_ = 0; + + // The loop fusion is injecting the l1 before the l2, the latch of l1 will be + // connected to the header of l2. + // To compute the register usage, we inject the loop live-in (union of l1 and + // l2 live-in header blocks) into the the live in/out of each basic block of + // l1 to get the peak register usage. We then repeat the operation to for l2 + // basic blocks but in this case we inject the live-out of the latch of l1. + auto live_loop = MakeFilterIteratorRange( + sim_result->live_in_.begin(), sim_result->live_in_.end(), + [&l1, &l2](Instruction* insn) { + BasicBlock* bb = insn->context()->get_instr_block(insn); + return insn->HasResultId() && + !(insn->opcode() == SpvOpPhi && + (bb == l1.GetHeaderBlock() || bb == l2.GetHeaderBlock())); + }); + + for (uint32_t bb_id : l1.GetBlocks()) { + BasicBlock* bb = context_->cfg()->block(bb_id); + + const RegionRegisterLiveness* live_inout_info = Get(bb_id); + assert(live_inout_info != nullptr && "Basic block not processed"); + RegionRegisterLiveness::LiveSet live_out = live_inout_info->live_out_; + live_out.insert(live_loop.begin(), live_loop.end()); + sim_result->used_registers_ = + std::max(sim_result->used_registers_, + live_inout_info->used_registers_ + live_out.size() - + live_inout_info->live_out_.size()); + + for (Instruction& insn : *bb) { + if (insn.opcode() == SpvOpPhi || !CreatesRegisterUsage(&insn) || + seen_insn.count(insn.result_id())) { + continue; + } + sim_result->AddRegisterClass(&insn); + } + } + + const RegionRegisterLiveness* l1_latch_live_inout_info = + Get(l1.GetLatchBlock()->id()); + assert(l1_latch_live_inout_info != nullptr && "Basic block not processed"); + RegionRegisterLiveness::LiveSet l1_latch_live_out = + l1_latch_live_inout_info->live_out_; + l1_latch_live_out.insert(live_loop.begin(), live_loop.end()); + + auto live_loop_l2 = + make_range(l1_latch_live_out.begin(), l1_latch_live_out.end()); + + for (uint32_t bb_id : l2.GetBlocks()) { + BasicBlock* bb = context_->cfg()->block(bb_id); + + const RegionRegisterLiveness* live_inout_info = Get(bb_id); + assert(live_inout_info != nullptr && "Basic block not processed"); + RegionRegisterLiveness::LiveSet live_out = live_inout_info->live_out_; + live_out.insert(live_loop_l2.begin(), live_loop_l2.end()); + sim_result->used_registers_ = + std::max(sim_result->used_registers_, + live_inout_info->used_registers_ + live_out.size() - + live_inout_info->live_out_.size()); + + for (Instruction& insn : *bb) { + if (insn.opcode() == SpvOpPhi || !CreatesRegisterUsage(&insn) || + seen_insn.count(insn.result_id())) { + continue; + } + sim_result->AddRegisterClass(&insn); + } + } +} + +void RegisterLiveness::SimulateFission( + const Loop& loop, const std::unordered_set& moved_inst, + const std::unordered_set& copied_inst, + RegionRegisterLiveness* l1_sim_result, + RegionRegisterLiveness* l2_sim_result) const { + l1_sim_result->Clear(); + l2_sim_result->Clear(); + + // Filter predicates: consider instructions that only belong to the first and + // second loop. + auto belong_to_loop1 = [&moved_inst, &copied_inst, &loop](Instruction* insn) { + return moved_inst.count(insn) || copied_inst.count(insn) || + !loop.IsInsideLoop(insn); + }; + auto belong_to_loop2 = [&moved_inst](Instruction* insn) { + return !moved_inst.count(insn); + }; + + const RegionRegisterLiveness* header_live_inout = Get(loop.GetHeaderBlock()); + // l1 live-in + { + auto live_loop = MakeFilterIteratorRange( + header_live_inout->live_in_.begin(), header_live_inout->live_in_.end(), + belong_to_loop1); + l1_sim_result->live_in_.insert(live_loop.begin(), live_loop.end()); + } + // l2 live-in + { + auto live_loop = MakeFilterIteratorRange( + header_live_inout->live_in_.begin(), header_live_inout->live_in_.end(), + belong_to_loop2); + l2_sim_result->live_in_.insert(live_loop.begin(), live_loop.end()); + } + + std::unordered_set exit_blocks; + loop.GetExitBlocks(&exit_blocks); + + // l2 live-out. + for (uint32_t bb_id : exit_blocks) { + const RegionRegisterLiveness* live_inout = Get(bb_id); + l2_sim_result->live_out_.insert(live_inout->live_in_.begin(), + live_inout->live_in_.end()); + } + // l1 live-out. + { + auto live_out = MakeFilterIteratorRange(l2_sim_result->live_out_.begin(), + l2_sim_result->live_out_.end(), + belong_to_loop1); + l1_sim_result->live_out_.insert(live_out.begin(), live_out.end()); + } + { + auto live_out = + MakeFilterIteratorRange(l2_sim_result->live_in_.begin(), + l2_sim_result->live_in_.end(), belong_to_loop1); + l1_sim_result->live_out_.insert(live_out.begin(), live_out.end()); + } + // Lives out of l1 are live out of l2 so are live in of l2 as well. + l2_sim_result->live_in_.insert(l1_sim_result->live_out_.begin(), + l1_sim_result->live_out_.end()); + + for (Instruction* insn : l1_sim_result->live_in_) { + l1_sim_result->AddRegisterClass(insn); + } + for (Instruction* insn : l2_sim_result->live_in_) { + l2_sim_result->AddRegisterClass(insn); + } + + l1_sim_result->used_registers_ = 0; + l2_sim_result->used_registers_ = 0; + + for (uint32_t bb_id : loop.GetBlocks()) { + BasicBlock* bb = context_->cfg()->block(bb_id); + + const RegisterLiveness::RegionRegisterLiveness* live_inout = Get(bb_id); + assert(live_inout != nullptr && "Basic block not processed"); + auto l1_block_live_out = + MakeFilterIteratorRange(live_inout->live_out_.begin(), + live_inout->live_out_.end(), belong_to_loop1); + auto l2_block_live_out = + MakeFilterIteratorRange(live_inout->live_out_.begin(), + live_inout->live_out_.end(), belong_to_loop2); + + size_t l1_reg_count = + std::distance(l1_block_live_out.begin(), l1_block_live_out.end()); + size_t l2_reg_count = + std::distance(l2_block_live_out.begin(), l2_block_live_out.end()); + + std::unordered_set die_in_block; + for (Instruction& insn : make_range(bb->rbegin(), bb->rend())) { + if (insn.opcode() == SpvOpPhi) { + break; + } + + bool does_belong_to_loop1 = belong_to_loop1(&insn); + bool does_belong_to_loop2 = belong_to_loop2(&insn); + insn.ForEachInId([live_inout, &die_in_block, &l1_reg_count, &l2_reg_count, + does_belong_to_loop1, does_belong_to_loop2, + this](uint32_t* id) { + Instruction* op_insn = context_->get_def_use_mgr()->GetDef(*id); + if (!CreatesRegisterUsage(op_insn) || + live_inout->live_out_.count(op_insn)) { + // already taken into account. + return; + } + if (!die_in_block.count(*id)) { + if (does_belong_to_loop1) { + l1_reg_count++; + } + if (does_belong_to_loop2) { + l2_reg_count++; + } + die_in_block.insert(*id); + } + }); + l1_sim_result->used_registers_ = + std::max(l1_sim_result->used_registers_, l1_reg_count); + l2_sim_result->used_registers_ = + std::max(l2_sim_result->used_registers_, l2_reg_count); + if (CreatesRegisterUsage(&insn)) { + if (does_belong_to_loop1) { + if (!l1_sim_result->live_in_.count(&insn)) { + l1_sim_result->AddRegisterClass(&insn); + } + l1_reg_count--; + } + if (does_belong_to_loop2) { + if (!l2_sim_result->live_in_.count(&insn)) { + l2_sim_result->AddRegisterClass(&insn); + } + l2_reg_count--; + } + } + } + } +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/register_pressure.h b/3rdparty/spirv-tools/source/opt/register_pressure.h new file mode 100644 index 000000000..cb3d2e270 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/register_pressure.h @@ -0,0 +1,196 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_REGISTER_PRESSURE_H_ +#define SOURCE_OPT_REGISTER_PRESSURE_H_ + +#include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/types.h" + +namespace spvtools { +namespace opt { + +class IRContext; +class Loop; +class LoopDescriptor; + +// Handles the register pressure of a function for different regions (function, +// loop, basic block). It also contains some utilities to foresee the register +// pressure following code transformations. +class RegisterLiveness { + public: + // Classification of SSA registers. + struct RegisterClass { + analysis::Type* type_; + bool is_uniform_; + + bool operator==(const RegisterClass& rhs) const { + return std::tie(type_, is_uniform_) == + std::tie(rhs.type_, rhs.is_uniform_); + } + }; + + struct RegionRegisterLiveness { + using LiveSet = std::unordered_set; + using RegClassSetTy = std::vector>; + + // SSA register live when entering the basic block. + LiveSet live_in_; + // SSA register live when exiting the basic block. + LiveSet live_out_; + + // Maximum number of required registers. + size_t used_registers_; + // Break down of the number of required registers per class of register. + RegClassSetTy registers_classes_; + + void Clear() { + live_out_.clear(); + live_in_.clear(); + used_registers_ = 0; + registers_classes_.clear(); + } + + void AddRegisterClass(const RegisterClass& reg_class) { + auto it = std::find_if( + registers_classes_.begin(), registers_classes_.end(), + [®_class](const std::pair& class_count) { + return class_count.first == reg_class; + }); + if (it != registers_classes_.end()) { + it->second++; + } else { + registers_classes_.emplace_back(std::move(reg_class), + static_cast(1)); + } + } + + void AddRegisterClass(Instruction* insn); + }; + + RegisterLiveness(IRContext* context, Function* f) : context_(context) { + Analyze(f); + } + + // Returns liveness and register information for the basic block |bb|. If no + // entry exist for the basic block, the function returns null. + const RegionRegisterLiveness* Get(const BasicBlock* bb) const { + return Get(bb->id()); + } + + // Returns liveness and register information for the basic block id |bb_id|. + // If no entry exist for the basic block, the function returns null. + const RegionRegisterLiveness* Get(uint32_t bb_id) const { + RegionRegisterLivenessMap::const_iterator it = block_pressure_.find(bb_id); + if (it != block_pressure_.end()) { + return &it->second; + } + return nullptr; + } + + IRContext* GetContext() const { return context_; } + + // Returns liveness and register information for the basic block |bb|. If no + // entry exist for the basic block, the function returns null. + RegionRegisterLiveness* Get(const BasicBlock* bb) { return Get(bb->id()); } + + // Returns liveness and register information for the basic block id |bb_id|. + // If no entry exist for the basic block, the function returns null. + RegionRegisterLiveness* Get(uint32_t bb_id) { + RegionRegisterLivenessMap::iterator it = block_pressure_.find(bb_id); + if (it != block_pressure_.end()) { + return &it->second; + } + return nullptr; + } + + // Returns liveness and register information for the basic block id |bb_id| or + // create a new empty entry if no entry already existed. + RegionRegisterLiveness* GetOrInsert(uint32_t bb_id) { + return &block_pressure_[bb_id]; + } + + // Compute the register pressure for the |loop| and store the result into + // |reg_pressure|. The live-in set corresponds to the live-in set of the + // header block, the live-out set of the loop corresponds to the union of the + // live-in sets of each exit basic block. + void ComputeLoopRegisterPressure(const Loop& loop, + RegionRegisterLiveness* reg_pressure) const; + + // Estimate the register pressure for the |l1| and |l2| as if they were making + // one unique loop. The result is stored into |simulation_result|. + void SimulateFusion(const Loop& l1, const Loop& l2, + RegionRegisterLiveness* simulation_result) const; + + // Estimate the register pressure of |loop| after it has been fissioned + // according to |moved_instructions| and |copied_instructions|. The function + // assumes that the fission creates a new loop before |loop|, moves any + // instructions present inside |moved_instructions| and copies any + // instructions present inside |copied_instructions| into this new loop. + // The set |loop1_sim_result| store the simulation result of the loop with the + // moved instructions. The set |loop2_sim_result| store the simulation result + // of the loop with the removed instructions. + void SimulateFission( + const Loop& loop, + const std::unordered_set& moved_instructions, + const std::unordered_set& copied_instructions, + RegionRegisterLiveness* loop1_sim_result, + RegionRegisterLiveness* loop2_sim_result) const; + + private: + using RegionRegisterLivenessMap = + std::unordered_map; + + IRContext* context_; + RegionRegisterLivenessMap block_pressure_; + + void Analyze(Function* f); +}; + +// Handles the register pressure of a function for different regions (function, +// loop, basic block). It also contains some utilities to foresee the register +// pressure following code transformations. +class LivenessAnalysis { + using LivenessAnalysisMap = + std::unordered_map; + + public: + LivenessAnalysis(IRContext* context) : context_(context) {} + + // Computes the liveness analysis for the function |f| and cache the result. + // If the analysis was performed for this function, then the cached analysis + // is returned. + const RegisterLiveness* Get(Function* f) { + LivenessAnalysisMap::iterator it = analysis_cache_.find(f); + if (it != analysis_cache_.end()) { + return &it->second; + } + return &analysis_cache_.emplace(f, RegisterLiveness{context_, f}) + .first->second; + } + + private: + IRContext* context_; + LivenessAnalysisMap analysis_cache_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_REGISTER_PRESSURE_H_ diff --git a/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.cpp b/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.cpp index 0a54d76ea..a37e9df9e 100644 --- a/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.cpp @@ -12,49 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "remove_duplicates_pass.h" - -#include +#include "source/opt/remove_duplicates_pass.h" #include +#include #include +#include #include #include #include -#include "decoration_manager.h" -#include "ir_context.h" -#include "opcode.h" -#include "reflect.h" +#include "source/opcode.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" namespace spvtools { namespace opt { -using ir::Instruction; -using ir::Module; -using ir::Operand; -using opt::analysis::DecorationManager; -using opt::analysis::DefUseManager; - -Pass::Status RemoveDuplicatesPass::Process(ir::IRContext* ir_context) { - bool modified = RemoveDuplicateCapabilities(ir_context); - modified |= RemoveDuplicatesExtInstImports(ir_context); - modified |= RemoveDuplicateTypes(ir_context); - modified |= RemoveDuplicateDecorations(ir_context); +Pass::Status RemoveDuplicatesPass::Process() { + bool modified = RemoveDuplicateCapabilities(); + modified |= RemoveDuplicatesExtInstImports(); + modified |= RemoveDuplicateTypes(); + modified |= RemoveDuplicateDecorations(); return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -bool RemoveDuplicatesPass::RemoveDuplicateCapabilities( - ir::IRContext* ir_context) const { +bool RemoveDuplicatesPass::RemoveDuplicateCapabilities() const { bool modified = false; - if (ir_context->capabilities().empty()) { + if (context()->capabilities().empty()) { return modified; } std::unordered_set capabilities; - for (auto* i = &*ir_context->capability_begin(); i;) { + for (auto* i = &*context()->capability_begin(); i;) { auto res = capabilities.insert(i->GetSingleWordOperand(0u)); if (res.second) { @@ -62,7 +55,7 @@ bool RemoveDuplicatesPass::RemoveDuplicateCapabilities( i = i->NextNode(); } else { // It's a duplicate, remove it. - i = ir_context->KillInst(i); + i = context()->KillInst(i); modified = true; } } @@ -70,16 +63,15 @@ bool RemoveDuplicatesPass::RemoveDuplicateCapabilities( return modified; } -bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports( - ir::IRContext* ir_context) const { +bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports() const { bool modified = false; - if (ir_context->ext_inst_imports().empty()) { + if (context()->ext_inst_imports().empty()) { return modified; } std::unordered_map ext_inst_imports; - for (auto* i = &*ir_context->ext_inst_import_begin(); i;) { + for (auto* i = &*context()->ext_inst_import_begin(); i;) { auto res = ext_inst_imports.emplace( reinterpret_cast(i->GetInOperand(0u).words.data()), i->result_id()); @@ -88,8 +80,8 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports( i = i->NextNode(); } else { // It's a duplicate, remove it. - ir_context->ReplaceAllUsesWith(i->result_id(), res.first->second); - i = ir_context->KillInst(i); + context()->ReplaceAllUsesWith(i->result_id(), res.first->second); + i = context()->KillInst(i); modified = true; } } @@ -97,17 +89,16 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports( return modified; } -bool RemoveDuplicatesPass::RemoveDuplicateTypes( - ir::IRContext* ir_context) const { +bool RemoveDuplicatesPass::RemoveDuplicateTypes() const { bool modified = false; - if (ir_context->types_values().empty()) { + if (context()->types_values().empty()) { return modified; } std::vector visited_types; std::vector to_delete; - for (auto* i = &*ir_context->types_values_begin(); i; i = i->NextNode()) { + for (auto* i = &*context()->types_values_begin(); i; i = i->NextNode()) { // We only care about types. if (!spvOpcodeGeneratesType((i->opcode())) && i->opcode() != SpvOpTypeForwardPointer) { @@ -119,7 +110,7 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( // TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the // ResultIdTrie from unify_const_pass.cpp for this. for (auto j : visited_types) { - if (AreTypesEqual(*i, *j, ir_context)) { + if (AreTypesEqual(*i, *j, context())) { id_to_keep = j->result_id(); break; } @@ -130,15 +121,15 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( visited_types.emplace_back(i); } else { // The same type has already been seen before, remove this one. - ir_context->KillNamesAndDecorates(i->result_id()); - ir_context->ReplaceAllUsesWith(i->result_id(), id_to_keep); + context()->KillNamesAndDecorates(i->result_id()); + context()->ReplaceAllUsesWith(i->result_id(), id_to_keep); modified = true; to_delete.emplace_back(i); } } for (auto i : to_delete) { - ir_context->KillInst(i); + context()->KillInst(i); } return modified; @@ -153,14 +144,13 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( // OpGroupDecorate %1 %3 // OpGroupDecorate %2 %4 // group %2 could be removed. -bool RemoveDuplicatesPass::RemoveDuplicateDecorations( - ir::IRContext* ir_context) const { +bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const { bool modified = false; std::vector visited_decorations; - opt::analysis::DecorationManager decoration_manager(ir_context->module()); - for (auto* i = &*ir_context->annotation_begin(); i;) { + analysis::DecorationManager decoration_manager(context()->module()); + for (auto* i = &*context()->annotation_begin(); i;) { // Is the current decoration equal to one of the decorations we have aready // visited? bool already_visited = false; @@ -180,7 +170,7 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations( } else { // The same decoration has already been seen before, remove this one. modified = true; - i = ir_context->KillInst(i); + i = context()->KillInst(i); } } @@ -189,9 +179,9 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations( bool RemoveDuplicatesPass::AreTypesEqual(const Instruction& inst1, const Instruction& inst2, - ir::IRContext* context) { + IRContext* context) { if (inst1.opcode() != inst2.opcode()) return false; - if (!ir::IsTypeInst(inst1.opcode())) return false; + if (!IsTypeInst(inst1.opcode())) return false; const analysis::Type* type1 = context->get_type_mgr()->GetType(inst1.result_id()); diff --git a/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.h b/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.h index d766f6733..8554a987d 100644 --- a/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.h +++ b/3rdparty/spirv-tools/source/opt/remove_duplicates_pass.h @@ -12,56 +12,56 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_REMOVE_DUPLICATES_PASS_H_ -#define LIBSPIRV_OPT_REMOVE_DUPLICATES_PASS_H_ +#ifndef SOURCE_OPT_REMOVE_DUPLICATES_PASS_H_ +#define SOURCE_OPT_REMOVE_DUPLICATES_PASS_H_ #include +#include -#include "decoration_manager.h" -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { using IdDecorationsList = - std::unordered_map>; + std::unordered_map>; // See optimizer.hpp for documentation. class RemoveDuplicatesPass : public Pass { public: const char* name() const override { return "remove-duplicates"; } - Status Process(ir::IRContext*) override; + Status Process() override; + // TODO(pierremoreau): Move this function somewhere else (e.g. pass.h or // within the type manager) // Returns whether two types are equal, and have the same decorations. - static bool AreTypesEqual(const ir::Instruction& inst1, - const ir::Instruction& inst2, - ir::IRContext* context); + static bool AreTypesEqual(const Instruction& inst1, const Instruction& inst2, + IRContext* context); private: - // Remove duplicate capabilities from the module attached to |ir_context|. + // Remove duplicate capabilities from the module // // Returns true if the module was modified, false otherwise. - bool RemoveDuplicateCapabilities(ir::IRContext* ir_context) const; - // Remove duplicate extended instruction imports from the module attached to - // |ir_context|. + bool RemoveDuplicateCapabilities() const; + // Remove duplicate extended instruction imports from the module // // Returns true if the module was modified, false otherwise. - bool RemoveDuplicatesExtInstImports(ir::IRContext* ir_context) const; - // Remove duplicate types from the module attached to |ir_context|. + bool RemoveDuplicatesExtInstImports() const; + // Remove duplicate types from the module // // Returns true if the module was modified, false otherwise. - bool RemoveDuplicateTypes(ir::IRContext* ir_context) const; - // Remove duplicate decorations from the module attached to |ir_context|. + bool RemoveDuplicateTypes() const; + // Remove duplicate decorations from the module // // Returns true if the module was modified, false otherwise. - bool RemoveDuplicateDecorations(ir::IRContext* ir_context) const; + bool RemoveDuplicateDecorations() const; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_REMOVE_DUPLICATES_PASS_H_ +#endif // SOURCE_OPT_REMOVE_DUPLICATES_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/replace_invalid_opc.cpp b/3rdparty/spirv-tools/source/opt/replace_invalid_opc.cpp index a025c3ce9..4e0f24f46 100644 --- a/3rdparty/spirv-tools/source/opt/replace_invalid_opc.cpp +++ b/3rdparty/spirv-tools/source/opt/replace_invalid_opc.cpp @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "replace_invalid_opc.h" +#include "source/opt/replace_invalid_opc.h" #include +#include namespace spvtools { namespace opt { -Pass::Status ReplaceInvalidOpcodePass::Process(ir::IRContext* c) { - InitializeProcessing(c); +Pass::Status ReplaceInvalidOpcodePass::Process() { bool modified = false; if (context()->get_feature_mgr()->HasCapability(SpvCapabilityLinkage)) { @@ -38,7 +38,7 @@ Pass::Status ReplaceInvalidOpcodePass::Process(ir::IRContext* c) { return Status::SuccessWithoutChange; } - for (ir::Function& func : *get_module()) { + for (Function& func : *get_module()) { modified |= RewriteFunction(&func, execution_model); } return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); @@ -47,7 +47,7 @@ Pass::Status ReplaceInvalidOpcodePass::Process(ir::IRContext* c) { SpvExecutionModel ReplaceInvalidOpcodePass::GetExecutionModel() { SpvExecutionModel result = SpvExecutionModelMax; bool first = true; - for (ir::Instruction& entry_point : get_module()->entry_points()) { + for (Instruction& entry_point : get_module()->entry_points()) { if (first) { result = static_cast(entry_point.GetSingleWordInOperand(0)); @@ -64,12 +64,12 @@ SpvExecutionModel ReplaceInvalidOpcodePass::GetExecutionModel() { return result; } -bool ReplaceInvalidOpcodePass::RewriteFunction(ir::Function* function, +bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function, SpvExecutionModel model) { bool modified = false; - ir::Instruction* last_line_dbg_inst = nullptr; + Instruction* last_line_dbg_inst = nullptr; function->ForEachInst( - [model, &modified, &last_line_dbg_inst, this](ir::Instruction* inst) { + [model, &modified, &last_line_dbg_inst, this](Instruction* inst) { // Track the debug information so we can have a meaningful message. if (inst->opcode() == SpvOpLabel || inst->opcode() == SpvOpNoLine) { last_line_dbg_inst = nullptr; @@ -100,7 +100,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(ir::Function* function, ReplaceInstruction(inst, nullptr, 0, 0); } else { // Get the name of the source file. - ir::Instruction* file_name = context()->get_def_use_mgr()->GetDef( + Instruction* file_name = context()->get_def_use_mgr()->GetDef( last_line_dbg_inst->GetSingleWordInOperand(0)); const char* source = reinterpret_cast( &file_name->GetInOperand(0).words[0]); @@ -120,7 +120,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(ir::Function* function, } bool ReplaceInvalidOpcodePass::IsFragmentShaderOnlyInstruction( - ir::Instruction* inst) { + Instruction* inst) { switch (inst->opcode()) { case SpvOpDPdx: case SpvOpDPdy: @@ -147,7 +147,7 @@ bool ReplaceInvalidOpcodePass::IsFragmentShaderOnlyInstruction( } } -void ReplaceInvalidOpcodePass::ReplaceInstruction(ir::Instruction* inst, +void ReplaceInvalidOpcodePass::ReplaceInstruction(Instruction* inst, const char* source, uint32_t line_number, uint32_t column_number) { @@ -172,7 +172,7 @@ uint32_t ReplaceInvalidOpcodePass::GetSpecialConstant(uint32_t type_id) { analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); analysis::TypeManager* type_mgr = context()->get_type_mgr(); - ir::Instruction* type = context()->get_def_use_mgr()->GetDef(type_id); + Instruction* type = context()->get_def_use_mgr()->GetDef(type_id); if (type->opcode() == SpvOpTypeVector) { uint32_t component_const = GetSpecialConstant(type->GetSingleWordInOperand(0)); diff --git a/3rdparty/spirv-tools/source/opt/replace_invalid_opc.h b/3rdparty/spirv-tools/source/opt/replace_invalid_opc.h index e661fcec0..426bcac5e 100644 --- a/3rdparty/spirv-tools/source/opt/replace_invalid_opc.h +++ b/3rdparty/spirv-tools/source/opt/replace_invalid_opc.h @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_REPLACE_INVALID_OPC_H_ -#define LIBSPIRV_OPT_REPLACE_INVALID_OPC_H_ +#ifndef SOURCE_OPT_REPLACE_INVALID_OPC_H_ +#define SOURCE_OPT_REPLACE_INVALID_OPC_H_ -#include "pass.h" +#include + +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,8 +28,8 @@ namespace opt { // value, the instruction will simply be deleted. class ReplaceInvalidOpcodePass : public Pass { public: - const char* name() const override { return "replace-invalid-opcodes"; } - Status Process(ir::IRContext*) override; + const char* name() const override { return "replace-invalid-opcode"; } + Status Process() override; private: // Returns the execution model that is used by every entry point in the @@ -38,16 +40,16 @@ class ReplaceInvalidOpcodePass : public Pass { // Replaces all instructions in |function| that are invalid with execution // model |mode|, but valid for another shader model, with a special constant // value. See |GetSpecialConstant|. - bool RewriteFunction(ir::Function* function, SpvExecutionModel mode); + bool RewriteFunction(Function* function, SpvExecutionModel mode); // Returns true if |inst| is valid for fragment shaders only. - bool IsFragmentShaderOnlyInstruction(ir::Instruction* inst); + bool IsFragmentShaderOnlyInstruction(Instruction* inst); // Replaces all uses of the result of |inst|, if there is one, with the id of // a special constant. Then |inst| is killed. |inst| cannot be a block // terminator because the basic block will then become invalid. |inst| is no // longer valid after calling this function. - void ReplaceInstruction(ir::Instruction* inst, const char* source, + void ReplaceInstruction(Instruction* inst, const char* source, uint32_t line_number, uint32_t column_number); // Returns the id of a constant with type |type_id|. The type must be an @@ -62,4 +64,4 @@ class ReplaceInvalidOpcodePass : public Pass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_REPLACE_INVALID_OPC_H_ +#endif // SOURCE_OPT_REPLACE_INVALID_OPC_H_ diff --git a/3rdparty/spirv-tools/source/opt/scalar_analysis.cpp b/3rdparty/spirv-tools/source/opt/scalar_analysis.cpp index ccdb66c82..38555e649 100644 --- a/3rdparty/spirv-tools/source/opt/scalar_analysis.cpp +++ b/3rdparty/spirv-tools/source/opt/scalar_analysis.cpp @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/scalar_analysis.h" +#include "source/opt/scalar_analysis.h" #include #include #include #include -#include "opt/ir_context.h" +#include "source/opt/ir_context.h" // Transforms a given scalar operation instruction into a DAG representation. // @@ -48,8 +48,8 @@ namespace opt { uint32_t SENode::NumberOfNodes = 0; -ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(ir::IRContext* context) - : context_(context) { +ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context) + : context_(context), pretend_equal_{} { // Create and cached the CantComputeNode. cached_cant_compute_ = GetCachedOrAdd(std::unique_ptr(new SECantCompute(this))); @@ -73,14 +73,22 @@ SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) { } SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression( - const ir::Loop* loop, SENode* offset, SENode* coefficient) { + const Loop* loop, SENode* offset, SENode* coefficient) { assert(loop && "Recurrent add expressions must have a valid loop."); // If operands are can't compute then the whole graph is can't compute. if (offset->IsCantCompute() || coefficient->IsCantCompute()) return CreateCantComputeNode(); - std::unique_ptr phi_node{new SERecurrentNode(this, loop)}; + const Loop* loop_to_use = nullptr; + if (pretend_equal_[loop]) { + loop_to_use = pretend_equal_[loop]; + } else { + loop_to_use = loop; + } + + std::unique_ptr phi_node{ + new SERecurrentNode(this, loop_to_use)}; phi_node->AddOffset(offset); phi_node->AddCoefficient(coefficient); @@ -88,10 +96,10 @@ SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression( } SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp( - const ir::Instruction* multiply) { + const Instruction* multiply) { assert(multiply->opcode() == SpvOp::SpvOpIMul && "Multiply node did not come from a multiply instruction"); - opt::analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); SENode* op1 = AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0))); @@ -154,8 +162,7 @@ SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1, return GetCachedOrAdd(std::move(add_node)); } -SENode* ScalarEvolutionAnalysis::AnalyzeInstruction( - const ir::Instruction* inst) { +SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) { auto itr = recurrent_node_map_.find(inst); if (itr != recurrent_node_map_.end()) return itr->second; @@ -188,7 +195,7 @@ SENode* ScalarEvolutionAnalysis::AnalyzeInstruction( return output; } -SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const ir::Instruction* inst) { +SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) { if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0); assert(inst->opcode() == SpvOp::SpvOpConstant); @@ -196,12 +203,12 @@ SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const ir::Instruction* inst) { int64_t value = 0; // Look up the instruction in the constant manager. - const opt::analysis::Constant* constant = + const analysis::Constant* constant = context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id()); if (!constant) return CreateCantComputeNode(); - const opt::analysis::IntConstant* int_constant = constant->AsIntConstant(); + const analysis::IntConstant* int_constant = constant->AsIntConstant(); // Exit out if it is a 64 bit integer. if (!int_constant || int_constant->words().size() != 1) @@ -218,12 +225,12 @@ SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const ir::Instruction* inst) { // Handles both addition and subtraction. If the |sub| flag is set then the // addition will be op1+(-op2) otherwise op1+op2. -SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const ir::Instruction* inst) { +SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) { assert((inst->opcode() == SpvOp::SpvOpIAdd || inst->opcode() == SpvOp::SpvOpISub) && "Add node must be created from a OpIAdd or OpISub instruction"); - opt::analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); SENode* op1 = AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0))); @@ -239,30 +246,29 @@ SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const ir::Instruction* inst) { return CreateAddNode(op1, op2); } -SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction( - const ir::Instruction* phi) { +SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) { // The phi should only have two incoming value pairs. if (phi->NumInOperands() != 4) { return CreateCantComputeNode(); } - opt::analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); // Get the basic block this instruction belongs to. - ir::BasicBlock* basic_block = - context_->get_instr_block(const_cast(phi)); + BasicBlock* basic_block = + context_->get_instr_block(const_cast(phi)); // And then the function that the basic blocks belongs to. - ir::Function* function = basic_block->GetParent(); + Function* function = basic_block->GetParent(); // Use the function to get the loop descriptor. - ir::LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function); + LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function); // We only handle phis in loops at the moment. if (!loop_descriptor) return CreateCantComputeNode(); // Get the innermost loop which this block belongs to. - ir::Loop* loop = (*loop_descriptor)[basic_block->id()]; + Loop* loop = (*loop_descriptor)[basic_block->id()]; // If the loop doesn't exist or doesn't have a preheader or latch block, exit // out. @@ -270,7 +276,14 @@ SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction( loop->GetHeaderBlock() != basic_block) return recurrent_node_map_[phi] = CreateCantComputeNode(); - std::unique_ptr phi_node{new SERecurrentNode(this, loop)}; + const Loop* loop_to_use = nullptr; + if (pretend_equal_[loop]) { + loop_to_use = pretend_equal_[loop]; + } else { + loop_to_use = loop; + } + std::unique_ptr phi_node{ + new SERecurrentNode(this, loop_to_use)}; // We add the node to this map to allow it to be returned before the node is // fully built. This is needed as the subsequent call to AnalyzeInstruction @@ -283,7 +296,7 @@ SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction( uint32_t value_id = phi->GetSingleWordInOperand(i); uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1); - ir::Instruction* value_inst = def_use->GetDef(value_id); + Instruction* value_inst = def_use->GetDef(value_id); SENode* value_node = AnalyzeInstruction(value_inst); // If any operand is CantCompute then the whole graph is CantCompute. @@ -337,7 +350,7 @@ SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction( } SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode( - const ir::Instruction* inst) { + const Instruction* inst) { std::unique_ptr load_node{ new SEValueUnknown(this, inst->result_id())}; return GetCachedOrAdd(std::move(load_node)); @@ -360,11 +373,11 @@ SENode* ScalarEvolutionAnalysis::GetCachedOrAdd( return raw_ptr_to_node; } -bool ScalarEvolutionAnalysis::IsLoopInvariant(const ir::Loop* loop, +bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop, const SENode* node) const { for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) { if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) { - const ir::BasicBlock* header = rec->GetLoop()->GetHeaderBlock(); + const BasicBlock* header = rec->GetLoop()->GetHeaderBlock(); // If the loop which the recurrent expression belongs to is either |loop // or a nested loop inside |loop| then we assume it is variant. @@ -382,7 +395,7 @@ bool ScalarEvolutionAnalysis::IsLoopInvariant(const ir::Loop* loop, } SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm( - SENode* node, const ir::Loop* loop) { + SENode* node, const Loop* loop) { // Traverse the DAG to find the recurrent expression belonging to |loop|. for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { SERecurrentNode* rec = itr->AsSERecurrentNode(); @@ -419,7 +432,7 @@ SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent, // Rebuild the |node| eliminating, if it exists, the recurrent term which // belongs to the |loop|. SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm( - SENode* node, const ir::Loop* loop) { + SENode* node, const Loop* loop) { // If the node is already a recurrent expression belonging to loop then just // return the offset. SERecurrentNode* recurrent = node->AsSERecurrentNode(); @@ -452,8 +465,8 @@ SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm( // Return the recurrent term belonging to |loop| if it appears in the graph // starting at |node| or null if it doesn't. -SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm( - SENode* node, const ir::Loop* loop) { +SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node, + const Loop* loop) { for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { SERecurrentNode* rec = itr->AsSERecurrentNode(); if (rec && rec->GetLoop() == loop) { @@ -634,5 +647,342 @@ void SENode::DumpDot(std::ostream& out, bool recurse) const { } } +namespace { +class IsGreaterThanZero { + public: + explicit IsGreaterThanZero(IRContext* context) : context_(context) {} + + // Determine if the value of |node| is always strictly greater than zero if + // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is + // true. It returns true is the evaluation was able to conclude something, in + // which case the result is stored in |result|. + // The algorithm work by going through all the nodes and determine the + // sign of each of them. + bool Eval(const SENode* node, bool or_equal_zero, bool* result) { + *result = false; + switch (Visit(node)) { + case Signedness::kPositiveOrNegative: { + return false; + } + case Signedness::kStrictlyNegative: { + *result = false; + break; + } + case Signedness::kNegative: { + if (!or_equal_zero) { + return false; + } + *result = false; + break; + } + case Signedness::kStrictlyPositive: { + *result = true; + break; + } + case Signedness::kPositive: { + if (!or_equal_zero) { + return false; + } + *result = true; + break; + } + } + return true; + } + + private: + enum class Signedness { + kPositiveOrNegative, // Yield a value positive or negative. + kStrictlyNegative, // Yield a value strictly less than 0. + kNegative, // Yield a value less or equal to 0. + kStrictlyPositive, // Yield a value strictly greater than 0. + kPositive // Yield a value greater or equal to 0. + }; + + // Combine the signedness according to arithmetic rules of a given operator. + using Combiner = std::function; + + // Returns a functor to interpret the signedness of 2 expressions as if they + // were added. + Combiner GetAddCombiner() const { + return [](Signedness lhs, Signedness rhs) { + switch (lhs) { + case Signedness::kPositiveOrNegative: + break; + case Signedness::kStrictlyNegative: + if (rhs == Signedness::kStrictlyNegative || + rhs == Signedness::kNegative) + return lhs; + break; + case Signedness::kNegative: { + if (rhs == Signedness::kStrictlyNegative) + return Signedness::kStrictlyNegative; + if (rhs == Signedness::kNegative) return Signedness::kNegative; + break; + } + case Signedness::kStrictlyPositive: { + if (rhs == Signedness::kStrictlyPositive || + rhs == Signedness::kPositive) { + return Signedness::kStrictlyPositive; + } + break; + } + case Signedness::kPositive: { + if (rhs == Signedness::kStrictlyPositive) + return Signedness::kStrictlyPositive; + if (rhs == Signedness::kPositive) return Signedness::kPositive; + break; + } + } + return Signedness::kPositiveOrNegative; + }; + } + + // Returns a functor to interpret the signedness of 2 expressions as if they + // were multiplied. + Combiner GetMulCombiner() const { + return [](Signedness lhs, Signedness rhs) { + switch (lhs) { + case Signedness::kPositiveOrNegative: + break; + case Signedness::kStrictlyNegative: { + switch (rhs) { + case Signedness::kPositiveOrNegative: { + break; + } + case Signedness::kStrictlyNegative: { + return Signedness::kStrictlyPositive; + } + case Signedness::kNegative: { + return Signedness::kPositive; + } + case Signedness::kStrictlyPositive: { + return Signedness::kStrictlyNegative; + } + case Signedness::kPositive: { + return Signedness::kNegative; + } + } + break; + } + case Signedness::kNegative: { + switch (rhs) { + case Signedness::kPositiveOrNegative: { + break; + } + case Signedness::kStrictlyNegative: + case Signedness::kNegative: { + return Signedness::kPositive; + } + case Signedness::kStrictlyPositive: + case Signedness::kPositive: { + return Signedness::kNegative; + } + } + break; + } + case Signedness::kStrictlyPositive: { + return rhs; + } + case Signedness::kPositive: { + switch (rhs) { + case Signedness::kPositiveOrNegative: { + break; + } + case Signedness::kStrictlyNegative: + case Signedness::kNegative: { + return Signedness::kNegative; + } + case Signedness::kStrictlyPositive: + case Signedness::kPositive: { + return Signedness::kPositive; + } + } + break; + } + } + return Signedness::kPositiveOrNegative; + }; + } + + Signedness Visit(const SENode* node) { + switch (node->GetType()) { + case SENode::Constant: + return Visit(node->AsSEConstantNode()); + break; + case SENode::RecurrentAddExpr: + return Visit(node->AsSERecurrentNode()); + break; + case SENode::Negative: + return Visit(node->AsSENegative()); + break; + case SENode::CanNotCompute: + return Visit(node->AsSECantCompute()); + break; + case SENode::ValueUnknown: + return Visit(node->AsSEValueUnknown()); + break; + case SENode::Add: + return VisitExpr(node, GetAddCombiner()); + break; + case SENode::Multiply: + return VisitExpr(node, GetMulCombiner()); + break; + } + return Signedness::kPositiveOrNegative; + } + + // Returns the signedness of a constant |node|. + Signedness Visit(const SEConstantNode* node) { + if (0 == node->FoldToSingleValue()) return Signedness::kPositive; + if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive; + if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative; + return Signedness::kPositiveOrNegative; + } + + // Returns the signedness of an unknown |node| based on its type. + Signedness Visit(const SEValueUnknown* node) { + Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId()); + analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id()); + assert(type && "Can't retrieve a type for the instruction"); + analysis::Integer* int_type = type->AsInteger(); + assert(type && "Can't retrieve an integer type for the instruction"); + return int_type->IsSigned() ? Signedness::kPositiveOrNegative + : Signedness::kPositive; + } + + // Returns the signedness of a recurring expression. + Signedness Visit(const SERecurrentNode* node) { + Signedness coeff_sign = Visit(node->GetCoefficient()); + // SERecurrentNode represent an affine expression in the range [0, + // loop_bound], so the result cannot be strictly positive or negative. + switch (coeff_sign) { + default: + break; + case Signedness::kStrictlyNegative: + coeff_sign = Signedness::kNegative; + break; + case Signedness::kStrictlyPositive: + coeff_sign = Signedness::kPositive; + break; + } + return GetAddCombiner()(coeff_sign, Visit(node->GetOffset())); + } + + // Returns the signedness of a negation |node|. + Signedness Visit(const SENegative* node) { + switch (Visit(*node->begin())) { + case Signedness::kPositiveOrNegative: { + return Signedness::kPositiveOrNegative; + } + case Signedness::kStrictlyNegative: { + return Signedness::kStrictlyPositive; + } + case Signedness::kNegative: { + return Signedness::kPositive; + } + case Signedness::kStrictlyPositive: { + return Signedness::kStrictlyNegative; + } + case Signedness::kPositive: { + return Signedness::kNegative; + } + } + return Signedness::kPositiveOrNegative; + } + + Signedness Visit(const SECantCompute*) { + return Signedness::kPositiveOrNegative; + } + + // Returns the signedness of a binary expression by using the combiner + // |reduce|. + Signedness VisitExpr( + const SENode* node, + std::function reduce) { + Signedness result = Visit(*node->begin()); + for (const SENode* operand : make_range(++node->begin(), node->end())) { + if (result == Signedness::kPositiveOrNegative) { + return Signedness::kPositiveOrNegative; + } + result = reduce(result, Visit(operand)); + } + return result; + } + + IRContext* context_; +}; +} // namespace + +bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node, + bool* is_gt_zero) const { + return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero); +} + +bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero( + SENode* node, bool* is_ge_zero) const { + return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero); +} + +namespace { + +// Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z), +// if |node| is not in the chain, returns the original chain. +static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul, + const SENode* node) { + SENode* lhs = mul->GetChildren()[0]; + SENode* rhs = mul->GetChildren()[1]; + if (lhs == node) { + return rhs; + } + if (rhs == node) { + return lhs; + } + if (lhs->AsSEMultiplyNode()) { + SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node); + if (res != lhs) + return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); + } + if (rhs->AsSEMultiplyNode()) { + SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node); + if (res != rhs) + return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); + } + + return mul; +} +} // namespace + +std::pair SExpression::operator/( + SExpression rhs_wrapper) const { + SENode* lhs = node_; + SENode* rhs = rhs_wrapper.node_; + // Check for division by 0. + if (rhs->AsSEConstantNode() && + !rhs->AsSEConstantNode()->FoldToSingleValue()) { + return {scev_->CreateCantComputeNode(), 0}; + } + + // Trivial case. + if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) { + int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue(); + int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue(); + return {scev_->CreateConstant(lhs_value / rhs_value), + lhs_value % rhs_value}; + } + + // look for a "c U / U" pattern. + if (lhs->AsSEMultiplyNode()) { + assert(lhs->GetChildren().size() == 2 && + "More than 2 operand for a multiply node."); + SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs); + if (res != lhs) { + return {res, 0}; + } + } + + return {scev_->CreateCantComputeNode(), 0}; +} + } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/scalar_analysis.h b/3rdparty/spirv-tools/source/opt/scalar_analysis.h index 71cc424b5..fb6d631f5 100644 --- a/3rdparty/spirv-tools/source/opt/scalar_analysis.h +++ b/3rdparty/spirv-tools/source/opt/scalar_analysis.h @@ -20,19 +20,18 @@ #include #include #include +#include #include -#include "opt/basic_block.h" -#include "opt/instruction.h" -#include "opt/scalar_analysis_nodes.h" +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/scalar_analysis_nodes.h" namespace spvtools { -namespace ir { +namespace opt { + class IRContext; class Loop; -} // namespace ir - -namespace opt { // Manager for the Scalar Evolution analysis. Creates and maintains a DAG of // scalar operations generated from analysing the use def graph from incoming @@ -42,7 +41,7 @@ namespace opt { // usable form with SimplifyExpression. class ScalarEvolutionAnalysis { public: - explicit ScalarEvolutionAnalysis(ir::IRContext* context); + explicit ScalarEvolutionAnalysis(IRContext* context); // Create a unary negative node on |operand|. SENode* CreateNegation(SENode* operand); @@ -63,18 +62,18 @@ class ScalarEvolutionAnalysis { SENode* CreateConstant(int64_t integer); // Create a value unknown node, such as a load. - SENode* CreateValueUnknownNode(const ir::Instruction* inst); + SENode* CreateValueUnknownNode(const Instruction* inst); // Create a CantComputeNode. Used to exit out of analysis. SENode* CreateCantComputeNode(); // Create a new recurrent node with |offset| and |coefficient|, with respect // to |loop|. - SENode* CreateRecurrentExpression(const ir::Loop* loop, SENode* offset, + SENode* CreateRecurrentExpression(const Loop* loop, SENode* offset, SENode* coefficient); // Construct the DAG by traversing use def chain of |inst|. - SENode* AnalyzeInstruction(const ir::Instruction* inst); + SENode* AnalyzeInstruction(const Instruction* inst); // Simplify the |node| by grouping like terms or if contains a recurrent // expression, rewrite the graph so the whole DAG (from |node| down) is in @@ -93,43 +92,60 @@ class ScalarEvolutionAnalysis { SENode* GetCachedOrAdd(std::unique_ptr prospective_node); // Checks that the graph starting from |node| is invariant to the |loop|. - bool IsLoopInvariant(const ir::Loop* loop, const SENode* node) const; + bool IsLoopInvariant(const Loop* loop, const SENode* node) const; + + // Sets |is_gt_zero| to true if |node| represent a value always strictly + // greater than 0. The result of |is_gt_zero| is valid only if the function + // returns true. + bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const; + + // Sets |is_ge_zero| to true if |node| represent a value greater or equals to + // 0. The result of |is_ge_zero| is valid only if the function returns true. + bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const; // Find the recurrent term belonging to |loop| in the graph starting from // |node| and return the coefficient of that recurrent term. Constant zero // will be returned if no recurrent could be found. |node| should be in // simplest form. - SENode* GetCoefficientFromRecurrentTerm(SENode* node, const ir::Loop* loop); + SENode* GetCoefficientFromRecurrentTerm(SENode* node, const Loop* loop); // Return a rebuilt graph starting from |node| with the recurrent expression // belonging to |loop| being zeroed out. Returned node will be simplified. - SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const ir::Loop* loop); + SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const Loop* loop); // Return the recurrent term belonging to |loop| if it appears in the graph // starting at |node| or null if it doesn't. - SERecurrentNode* GetRecurrentTerm(SENode* node, const ir::Loop* loop); + SERecurrentNode* GetRecurrentTerm(SENode* node, const Loop* loop); SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child); + // The loops in |loop_pair| will be considered the same when constructing + // SERecurrentNode objects. This enables analysing dependencies that will be + // created during loop fusion. + void AddLoopsToPretendAreTheSame( + const std::pair& loop_pair) { + pretend_equal_[std::get<1>(loop_pair)] = std::get<0>(loop_pair); + } + private: - SENode* AnalyzeConstant(const ir::Instruction* inst); + SENode* AnalyzeConstant(const Instruction* inst); // Handles both addition and subtraction. If the |instruction| is OpISub // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then // the result will be op1+op2. |instruction| must be OpIAdd or OpISub. - SENode* AnalyzeAddOp(const ir::Instruction* instruction); + SENode* AnalyzeAddOp(const Instruction* instruction); - SENode* AnalyzeMultiplyOp(const ir::Instruction* multiply); + SENode* AnalyzeMultiplyOp(const Instruction* multiply); - SENode* AnalyzePhiInstruction(const ir::Instruction* phi); + SENode* AnalyzePhiInstruction(const Instruction* phi); - ir::IRContext* context_; + IRContext* context_; // A map of instructions to SENodes. This is used to track recurrent // expressions as they are added when analyzing instructions. Recurrent // expressions come from phi nodes which by nature can include recursion so we // check if nodes have already been built when analyzing instructions. - std::map recurrent_node_map_; + std::map recurrent_node_map_; // On creation we create and cache the CantCompute node so we not need to // perform a needless create step. @@ -149,8 +165,150 @@ class ScalarEvolutionAnalysis { // managed by they set. std::unordered_set, SENodeHash, NodePointersEquality> node_cache_; + + // Loops that should be considered the same for performing analysis for loop + // fusion. + std::map pretend_equal_; }; +// Wrapping class to manipulate SENode pointer using + - * / operators. +class SExpression { + public: + // Implicit on purpose ! + SExpression(SENode* node) + : node_(node->GetParentAnalysis()->SimplifyExpression(node)), + scev_(node->GetParentAnalysis()) {} + + inline operator SENode*() const { return node_; } + inline SENode* operator->() const { return node_; } + const SENode& operator*() const { return *node_; } + + inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const { + return scev_; + } + + inline SExpression operator+(SENode* rhs) const; + template ::value, int>::type = 0> + inline SExpression operator+(T integer) const; + inline SExpression operator+(SExpression rhs) const; + + inline SExpression operator-() const; + inline SExpression operator-(SENode* rhs) const; + template ::value, int>::type = 0> + inline SExpression operator-(T integer) const; + inline SExpression operator-(SExpression rhs) const; + + inline SExpression operator*(SENode* rhs) const; + template ::value, int>::type = 0> + inline SExpression operator*(T integer) const; + inline SExpression operator*(SExpression rhs) const; + + template ::value, int>::type = 0> + inline std::pair operator/(T integer) const; + // Try to perform a division. Returns the pair . If it fails to simplify it, the function returns a + // CanNotCompute node. + std::pair operator/(SExpression rhs) const; + + private: + SENode* node_; + ScalarEvolutionAnalysis* scev_; +}; + +inline SExpression SExpression::operator+(SENode* rhs) const { + return scev_->CreateAddNode(node_, rhs); +} + +template ::value, int>::type> +inline SExpression SExpression::operator+(T integer) const { + return *this + scev_->CreateConstant(integer); +} + +inline SExpression SExpression::operator+(SExpression rhs) const { + return *this + rhs.node_; +} + +inline SExpression SExpression::operator-() const { + return scev_->CreateNegation(node_); +} + +inline SExpression SExpression::operator-(SENode* rhs) const { + return *this + scev_->CreateNegation(rhs); +} + +template ::value, int>::type> +inline SExpression SExpression::operator-(T integer) const { + return *this - scev_->CreateConstant(integer); +} + +inline SExpression SExpression::operator-(SExpression rhs) const { + return *this - rhs.node_; +} + +inline SExpression SExpression::operator*(SENode* rhs) const { + return scev_->CreateMultiplyNode(node_, rhs); +} + +template ::value, int>::type> +inline SExpression SExpression::operator*(T integer) const { + return *this * scev_->CreateConstant(integer); +} + +inline SExpression SExpression::operator*(SExpression rhs) const { + return *this * rhs.node_; +} + +template ::value, int>::type> +inline std::pair SExpression::operator/(T integer) const { + return *this / scev_->CreateConstant(integer); +} + +template ::value, int>::type> +inline SExpression operator+(T lhs, SExpression rhs) { + return rhs + lhs; +} +inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; } + +template ::value, int>::type> +inline SExpression operator-(T lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} - + rhs; +} +inline SExpression operator-(SENode* lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{lhs} - rhs; +} + +template ::value, int>::type> +inline SExpression operator*(T lhs, SExpression rhs) { + return rhs * lhs; +} +inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; } + +template ::value, int>::type> +inline std::pair operator/(T lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} / + rhs; +} +inline std::pair operator/(SENode* lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{lhs} / rhs; +} + } // namespace opt } // namespace spvtools -#endif // SOURCE_OPT_SCALAR_ANALYSIS_H__ +#endif // SOURCE_OPT_SCALAR_ANALYSIS_H_ diff --git a/3rdparty/spirv-tools/source/opt/scalar_analysis_nodes.h b/3rdparty/spirv-tools/source/opt/scalar_analysis_nodes.h index 094ee8e24..450522ec3 100644 --- a/3rdparty/spirv-tools/source/opt/scalar_analysis_nodes.h +++ b/3rdparty/spirv-tools/source/opt/scalar_analysis_nodes.h @@ -19,15 +19,13 @@ #include #include #include -#include "opt/tree_iterator.h" + +#include "source/opt/tree_iterator.h" namespace spvtools { -namespace ir { -class Loop; -} // namespace ir - namespace opt { +class Loop; class ScalarEvolutionAnalysis; class SEConstantNode; class SERecurrentNode; @@ -56,7 +54,7 @@ class SENode { using ChildContainerType = std::vector; - explicit SENode(opt::ScalarEvolutionAnalysis* parent_analysis) + explicit SENode(ScalarEvolutionAnalysis* parent_analysis) : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {} virtual SENodeType GetType() const = 0; @@ -115,6 +113,42 @@ class SENode { const_iterator cbegin() { return children_.cbegin(); } const_iterator cend() { return children_.cend(); } + // Collect all the recurrent nodes in this SENode + std::vector CollectRecurrentNodes() { + std::vector recurrent_nodes{}; + + if (auto recurrent_node = AsSERecurrentNode()) { + recurrent_nodes.push_back(recurrent_node); + } + + for (auto child : GetChildren()) { + auto child_recurrent_nodes = child->CollectRecurrentNodes(); + recurrent_nodes.insert(recurrent_nodes.end(), + child_recurrent_nodes.begin(), + child_recurrent_nodes.end()); + } + + return recurrent_nodes; + } + + // Collect all the value unknown nodes in this SENode + std::vector CollectValueUnknownNodes() { + std::vector value_unknown_nodes{}; + + if (auto value_unknown_node = AsSEValueUnknown()) { + value_unknown_nodes.push_back(value_unknown_node); + } + + for (auto child : GetChildren()) { + auto child_value_unknown_nodes = child->CollectValueUnknownNodes(); + value_unknown_nodes.insert(value_unknown_nodes.end(), + child_value_unknown_nodes.begin(), + child_value_unknown_nodes.end()); + } + + return value_unknown_nodes; + } + // Iterator to iterate over the entire DAG. Even though we are using the tree // iterator it should still be safe to iterate over. However, nodes with // multiple parents will be visited multiple times, unlike in a tree. @@ -150,14 +184,14 @@ class SENode { #undef DeclareCastMethod // Get the analysis which has this node in its cache. - inline opt::ScalarEvolutionAnalysis* GetParentAnalysis() const { + inline ScalarEvolutionAnalysis* GetParentAnalysis() const { return parent_analysis_; } protected: ChildContainerType children_; - opt::ScalarEvolutionAnalysis* parent_analysis_; + ScalarEvolutionAnalysis* parent_analysis_; // The unique id of this node, assigned on creation by incrementing the static // node count. @@ -178,7 +212,7 @@ struct SENodeHash { // A node representing a constant integer. class SEConstantNode : public SENode { public: - SEConstantNode(opt::ScalarEvolutionAnalysis* parent_analysis, int64_t value) + SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value) : SENode(parent_analysis), literal_value_(value) {} SENodeType GetType() const final { return Constant; } @@ -204,8 +238,7 @@ class SEConstantNode : public SENode { // of zero and a coefficient of one. class SERecurrentNode : public SENode { public: - SERecurrentNode(opt::ScalarEvolutionAnalysis* parent_analysis, - const ir::Loop* loop) + SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop) : SENode(parent_analysis), loop_(loop) {} SENodeType GetType() const final { return RecurrentAddExpr; } @@ -227,7 +260,7 @@ class SERecurrentNode : public SENode { inline SENode* GetOffset() { return offset_; } // Return the loop which this recurrent expression is recurring within. - const ir::Loop* GetLoop() const { return loop_; } + const Loop* GetLoop() const { return loop_; } SERecurrentNode* AsSERecurrentNode() override { return this; } const SERecurrentNode* AsSERecurrentNode() const override { return this; } @@ -235,13 +268,13 @@ class SERecurrentNode : public SENode { private: SENode* coefficient_; SENode* offset_; - const ir::Loop* loop_; + const Loop* loop_; }; // A node representing an addition operation between child nodes. class SEAddNode : public SENode { public: - explicit SEAddNode(opt::ScalarEvolutionAnalysis* parent_analysis) + explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis) : SENode(parent_analysis) {} SENodeType GetType() const final { return Add; } @@ -253,7 +286,7 @@ class SEAddNode : public SENode { // A node representing a multiply operation between child nodes. class SEMultiplyNode : public SENode { public: - explicit SEMultiplyNode(opt::ScalarEvolutionAnalysis* parent_analysis) + explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis) : SENode(parent_analysis) {} SENodeType GetType() const final { return Multiply; } @@ -265,7 +298,7 @@ class SEMultiplyNode : public SENode { // A node representing a unary negative operation. class SENegative : public SENode { public: - explicit SENegative(opt::ScalarEvolutionAnalysis* parent_analysis) + explicit SENegative(ScalarEvolutionAnalysis* parent_analysis) : SENode(parent_analysis) {} SENodeType GetType() const final { return Negative; } @@ -281,8 +314,7 @@ class SEValueUnknown : public SENode { // SEValueUnknowns must come from an instruction |unique_id| is the unique id // of that instruction. This is so we cancompare value unknowns and have a // unique value unknown for each instruction. - SEValueUnknown(opt::ScalarEvolutionAnalysis* parent_analysis, - uint32_t result_id) + SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id) : SENode(parent_analysis), result_id_(result_id) {} SENodeType GetType() const final { return ValueUnknown; } @@ -299,7 +331,7 @@ class SEValueUnknown : public SENode { // A node which we cannot reason about at all. class SECantCompute : public SENode { public: - explicit SECantCompute(opt::ScalarEvolutionAnalysis* parent_analysis) + explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis) : SENode(parent_analysis) {} SENodeType GetType() const final { return CanNotCompute; } diff --git a/3rdparty/spirv-tools/source/opt/scalar_analysis_simplification.cpp b/3rdparty/spirv-tools/source/opt/scalar_analysis_simplification.cpp index 018896a46..52f2d6ad9 100644 --- a/3rdparty/spirv-tools/source/opt/scalar_analysis_simplification.cpp +++ b/3rdparty/spirv-tools/source/opt/scalar_analysis_simplification.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/scalar_analysis.h" +#include "source/opt/scalar_analysis.h" #include #include @@ -375,7 +375,7 @@ SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) { // A mapping of loops to the list of recurrent expressions which are with // respect to those loops. - std::map>> + std::map>> loops_to_recurrent{}; bool has_multiple_same_loop_recurrent_terms = false; @@ -389,7 +389,7 @@ SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) { } if (child->GetType() == SENode::RecurrentAddExpr) { - const ir::Loop* loop = child->AsSERecurrentNode()->GetLoop(); + const Loop* loop = child->AsSERecurrentNode()->GetLoop(); SERecurrentNode* rec = child->AsSERecurrentNode(); if (loops_to_recurrent.find(loop) == loops_to_recurrent.end()) { @@ -408,7 +408,7 @@ SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) { for (auto pair : loops_to_recurrent) { std::vector>& recurrent_expressions = pair.second; - const ir::Loop* loop = pair.first; + const Loop* loop = pair.first; std::unique_ptr new_coefficient{new SEAddNode(&analysis_)}; std::unique_ptr new_offset{new SEAddNode(&analysis_)}; diff --git a/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.cpp b/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.cpp index 8e96f70eb..d51dd8ef2 100644 --- a/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.cpp @@ -12,26 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "scalar_replacement_pass.h" - -#include "enum_string_mapping.h" -#include "extensions.h" -#include "make_unique.h" -#include "reflect.h" -#include "types.h" +#include "source/opt/scalar_replacement_pass.h" +#include #include #include +#include + +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/opt/reflect.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { -// Heuristic aggregate element limit. -const uint32_t MAX_NUM_ELEMENTS = 100u; - -Pass::Status ScalarReplacementPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status ScalarReplacementPass::Process() { Status status = Status::SuccessWithoutChange; for (auto& f : *get_module()) { Status functionStatus = ProcessFunction(&f); @@ -44,15 +41,15 @@ Pass::Status ScalarReplacementPass::Process(ir::IRContext* c) { return status; } -Pass::Status ScalarReplacementPass::ProcessFunction(ir::Function* function) { - std::queue worklist; - ir::BasicBlock& entry = *function->begin(); +Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { + std::queue worklist; + BasicBlock& entry = *function->begin(); for (auto iter = entry.begin(); iter != entry.end(); ++iter) { // Function storage class OpVariables must appear as the first instructions // of the entry block. if (iter->opcode() != SpvOpVariable) break; - ir::Instruction* varInst = &*iter; + Instruction* varInst = &*iter; if (CanReplaceVariable(varInst)) { worklist.push(varInst); } @@ -60,7 +57,7 @@ Pass::Status ScalarReplacementPass::ProcessFunction(ir::Function* function) { Status status = Status::SuccessWithoutChange; while (!worklist.empty()) { - ir::Instruction* varInst = worklist.front(); + Instruction* varInst = worklist.front(); worklist.pop(); if (!ReplaceVariable(varInst, &worklist)) @@ -73,15 +70,15 @@ Pass::Status ScalarReplacementPass::ProcessFunction(ir::Function* function) { } bool ScalarReplacementPass::ReplaceVariable( - ir::Instruction* inst, std::queue* worklist) { - std::vector replacements; + Instruction* inst, std::queue* worklist) { + std::vector replacements; CreateReplacementVariables(inst, &replacements); - std::vector dead; + std::vector dead; dead.push_back(inst); if (!get_def_use_mgr()->WhileEachUser( - inst, [this, &replacements, &dead](ir::Instruction* user) { - if (!ir::IsAnnotationInst(user->opcode())) { + inst, [this, &replacements, &dead](Instruction* user) { + if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { case SpvOpLoad: ReplaceWholeLoad(user, replacements); @@ -110,18 +107,19 @@ bool ScalarReplacementPass::ReplaceVariable( // Clean up some dead code. while (!dead.empty()) { - ir::Instruction* toKill = dead.back(); + Instruction* toKill = dead.back(); dead.pop_back(); - context()->KillInst(toKill); } // Attempt to further scalarize. for (auto var : replacements) { - if (get_def_use_mgr()->NumUsers(var) == 0) { - context()->KillInst(var); - } else if (CanReplaceVariable(var)) { - worklist->push(var); + if (var->opcode() == SpvOpVariable) { + if (get_def_use_mgr()->NumUsers(var) == 0) { + context()->KillInst(var); + } else if (CanReplaceVariable(var)) { + worklist->push(var); + } } } @@ -129,25 +127,30 @@ bool ScalarReplacementPass::ReplaceVariable( } void ScalarReplacementPass::ReplaceWholeLoad( - ir::Instruction* load, const std::vector& replacements) { + Instruction* load, const std::vector& replacements) { // Replaces the load of the entire composite with a load from each replacement // variable followed by a composite construction. - ir::BasicBlock* block = context()->get_instr_block(load); - std::vector loads; + BasicBlock* block = context()->get_instr_block(load); + std::vector loads; loads.reserve(replacements.size()); - ir::BasicBlock::iterator where(load); + BasicBlock::iterator where(load); for (auto var : replacements) { // Create a load of each replacement variable. - ir::Instruction* type = GetStorageType(var); + if (var->opcode() != SpvOpVariable) { + loads.push_back(var); + continue; + } + + Instruction* type = GetStorageType(var); uint32_t loadId = TakeNextId(); - std::unique_ptr newLoad( - new ir::Instruction(context(), SpvOpLoad, type->result_id(), loadId, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); + std::unique_ptr newLoad( + new Instruction(context(), SpvOpLoad, type->result_id(), loadId, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); // Copy memory access attributes which start at index 1. Index 0 is the // pointer to load. for (uint32_t i = 1; i < load->NumInOperands(); ++i) { - ir::Operand copy(load->GetInOperand(i)); + Operand copy(load->GetInOperand(i)); newLoad->AddOperand(std::move(copy)); } where = where.InsertBefore(std::move(newLoad)); @@ -159,11 +162,11 @@ void ScalarReplacementPass::ReplaceWholeLoad( // Construct a new composite. uint32_t compositeId = TakeNextId(); where = load; - std::unique_ptr compositeConstruct(new ir::Instruction( + std::unique_ptr compositeConstruct(new Instruction( context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {})); for (auto l : loads) { - ir::Operand op(SPV_OPERAND_TYPE_ID, - std::initializer_list{l->result_id()}); + Operand op(SPV_OPERAND_TYPE_ID, + std::initializer_list{l->result_id()}); compositeConstruct->AddOperand(std::move(op)); } where = where.InsertBefore(std::move(compositeConstruct)); @@ -173,20 +176,25 @@ void ScalarReplacementPass::ReplaceWholeLoad( } void ScalarReplacementPass::ReplaceWholeStore( - ir::Instruction* store, const std::vector& replacements) { + Instruction* store, const std::vector& replacements) { // Replaces a store to the whole composite with a series of extract and stores // to each element. uint32_t storeInput = store->GetSingleWordInOperand(1u); - ir::BasicBlock* block = context()->get_instr_block(store); - ir::BasicBlock::iterator where(store); + BasicBlock* block = context()->get_instr_block(store); + BasicBlock::iterator where(store); uint32_t elementIndex = 0; for (auto var : replacements) { // Create the extract. - ir::Instruction* type = GetStorageType(var); + if (var->opcode() != SpvOpVariable) { + elementIndex++; + continue; + } + + Instruction* type = GetStorageType(var); uint32_t extractId = TakeNextId(); - std::unique_ptr extract(new ir::Instruction( + std::unique_ptr extract(new Instruction( context(), SpvOpCompositeExtract, type->result_id(), extractId, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {storeInput}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); auto iter = where.InsertBefore(std::move(extract)); @@ -194,15 +202,15 @@ void ScalarReplacementPass::ReplaceWholeStore( context()->set_instr_block(&*iter, block); // Create the store. - std::unique_ptr newStore( - new ir::Instruction(context(), SpvOpStore, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {var->result_id()}}, - {SPV_OPERAND_TYPE_ID, {extractId}}})); + std::unique_ptr newStore( + new Instruction(context(), SpvOpStore, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}, + {SPV_OPERAND_TYPE_ID, {extractId}}})); // Copy memory access attributes which start at index 2. Index 0 is the // pointer and index 1 is the data. for (uint32_t i = 2; i < store->NumInOperands(); ++i) { - ir::Operand copy(store->GetInOperand(i)); + Operand copy(store->GetInOperand(i)); newStore->AddOperand(std::move(copy)); } iter = where.InsertBefore(std::move(newStore)); @@ -212,28 +220,28 @@ void ScalarReplacementPass::ReplaceWholeStore( } bool ScalarReplacementPass::ReplaceAccessChain( - ir::Instruction* chain, const std::vector& replacements) { + Instruction* chain, const std::vector& replacements) { // Replaces the access chain with either another access chain (with one fewer // indexes) or a direct use of the replacement variable. uint32_t indexId = chain->GetSingleWordInOperand(1u); - const ir::Instruction* index = get_def_use_mgr()->GetDef(indexId); + const Instruction* index = get_def_use_mgr()->GetDef(indexId); size_t indexValue = GetConstantInteger(index); if (indexValue > replacements.size()) { // Out of bounds access, this is illegal IR. return false; } else { - const ir::Instruction* var = replacements[indexValue]; + const Instruction* var = replacements[indexValue]; if (chain->NumInOperands() > 2) { // Replace input access chain with another access chain. - ir::BasicBlock::iterator chainIter(chain); + BasicBlock::iterator chainIter(chain); uint32_t replacementId = TakeNextId(); - std::unique_ptr replacementChain(new ir::Instruction( + std::unique_ptr replacementChain(new Instruction( context(), chain->opcode(), chain->type_id(), replacementId, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); // Add the remaining indexes. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { - ir::Operand copy(chain->GetInOperand(i)); + Operand copy(chain->GetInOperand(i)); replacementChain->AddOperand(std::move(copy)); } auto iter = chainIter.InsertBefore(std::move(replacementChain)); @@ -250,18 +258,34 @@ bool ScalarReplacementPass::ReplaceAccessChain( } void ScalarReplacementPass::CreateReplacementVariables( - ir::Instruction* inst, std::vector* replacements) { - ir::Instruction* type = GetStorageType(inst); + Instruction* inst, std::vector* replacements) { + Instruction* type = GetStorageType(inst); + + std::unique_ptr> components_used = + GetUsedComponents(inst); + uint32_t elem = 0; switch (type->opcode()) { case SpvOpTypeStruct: - type->ForEachInOperand([this, inst, &elem, replacements](uint32_t* id) { - CreateVariable(*id, inst, elem++, replacements); - }); + type->ForEachInOperand( + [this, inst, &elem, replacements, &components_used](uint32_t* id) { + if (!components_used || components_used->count(elem)) { + CreateVariable(*id, inst, elem, replacements); + } else { + replacements->push_back(CreateNullConstant(*id)); + } + elem++; + }); break; case SpvOpTypeArray: for (uint32_t i = 0; i != GetArrayLength(type); ++i) { - CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); + if (!components_used || components_used->count(i)) { + CreateVariable(type->GetSingleWordInOperand(0u), inst, i, + replacements); + } else { + replacements->push_back( + CreateNullConstant(type->GetSingleWordInOperand(0u))); + } } break; @@ -281,8 +305,7 @@ void ScalarReplacementPass::CreateReplacementVariables( } void ScalarReplacementPass::TransferAnnotations( - const ir::Instruction* source, - std::vector* replacements) { + const Instruction* source, std::vector* replacements) { // Only transfer invariant and restrict decorations on the variable. There are // no type or member decorations that are necessary to transfer. for (auto inst : @@ -292,13 +315,13 @@ void ScalarReplacementPass::TransferAnnotations( if (decoration == SpvDecorationInvariant || decoration == SpvDecorationRestrict) { for (auto var : *replacements) { - std::unique_ptr annotation(new ir::Instruction( - context(), SpvOpDecorate, 0, 0, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {var->result_id()}}, - {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); + std::unique_ptr annotation( + new Instruction(context(), SpvOpDecorate, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}, + {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { - ir::Operand copy(inst->GetInOperand(i)); + Operand copy(inst->GetInOperand(i)); annotation->AddOperand(std::move(copy)); } context()->AddAnnotationInst(std::move(annotation)); @@ -309,18 +332,18 @@ void ScalarReplacementPass::TransferAnnotations( } void ScalarReplacementPass::CreateVariable( - uint32_t typeId, ir::Instruction* varInst, uint32_t index, - std::vector* replacements) { + uint32_t typeId, Instruction* varInst, uint32_t index, + std::vector* replacements) { uint32_t ptrId = GetOrCreatePointerType(typeId); uint32_t id = TakeNextId(); - std::unique_ptr variable(new ir::Instruction( + std::unique_ptr variable(new Instruction( context(), SpvOpVariable, ptrId, id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); - ir::BasicBlock* block = context()->get_instr_block(varInst); + BasicBlock* block = context()->get_instr_block(varInst); block->begin().InsertBefore(std::move(variable)); - ir::Instruction* inst = &*block->begin(); + Instruction* inst = &*block->begin(); // If varInst was initialized, make sure to initialize its replacement. GetOrCreateInitialValue(varInst, index, inst); @@ -353,11 +376,8 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { if (global.opcode() == SpvOpTypePointer && global.GetSingleWordInOperand(0u) == SpvStorageClassFunction && global.GetSingleWordInOperand(1u) == id) { - if (!context()->get_feature_mgr()->HasExtension( - libspirv::Extension::kSPV_KHR_variable_pointers) || - get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { - // If variable pointers is enabled, only reuse a decoration-less - // pointer of the correct type. + if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { + // Only reuse a decoration-less pointer of the correct type. ptrId = global.result_id(); break; } @@ -370,12 +390,12 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { } ptrId = TakeNextId(); - context()->AddType(MakeUnique( + context()->AddType(MakeUnique( context(), SpvOpTypePointer, 0, ptrId, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, {SPV_OPERAND_TYPE_ID, {id}}})); - ir::Instruction* ptr = &*--context()->types_values_end(); + Instruction* ptr = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(ptr); pointee_to_pointer_[id] = ptrId; // Register with the type manager if necessary. @@ -384,15 +404,15 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { return ptrId; } -void ScalarReplacementPass::GetOrCreateInitialValue(ir::Instruction* source, +void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, uint32_t index, - ir::Instruction* newVar) { + Instruction* newVar) { assert(source->opcode() == SpvOpVariable); if (source->NumInOperands() < 2) return; uint32_t initId = source->GetSingleWordInOperand(1u); uint32_t storageId = GetStorageType(newVar)->result_id(); - ir::Instruction* init = get_def_use_mgr()->GetDef(initId); + Instruction* init = get_def_use_mgr()->GetDef(initId); uint32_t newInitId = 0; // TODO(dnovillo): Refactor this with constant propagation. if (init->opcode() == SpvOpConstantNull) { @@ -401,29 +421,29 @@ void ScalarReplacementPass::GetOrCreateInitialValue(ir::Instruction* source, if (iter == type_to_null_.end()) { newInitId = TakeNextId(); type_to_null_[storageId] = newInitId; - context()->AddGlobalValue(MakeUnique( - context(), SpvOpConstantNull, storageId, newInitId, - std::initializer_list{})); - ir::Instruction* newNull = &*--context()->types_values_end(); + context()->AddGlobalValue( + MakeUnique(context(), SpvOpConstantNull, storageId, + newInitId, std::initializer_list{})); + Instruction* newNull = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(newNull); } else { newInitId = iter->second; } - } else if (ir::IsSpecConstantInst(init->opcode())) { + } else if (IsSpecConstantInst(init->opcode())) { // Create a new constant extract. newInitId = TakeNextId(); - context()->AddGlobalValue(MakeUnique( + context()->AddGlobalValue(MakeUnique( context(), SpvOpSpecConstantOp, storageId, newInitId, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}}, {SPV_OPERAND_TYPE_ID, {init->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}})); - ir::Instruction* newSpecConst = &*--context()->types_values_end(); + Instruction* newSpecConst = &*--context()->types_values_end(); get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst); } else if (init->opcode() == SpvOpConstantComposite) { // Get the appropriate index constant. newInitId = init->GetSingleWordInOperand(index); - ir::Instruction* element = get_def_use_mgr()->GetDef(newInitId); + Instruction* element = get_def_use_mgr()->GetDef(newInitId); if (element->opcode() == SpvOpUndef) { // Undef is not a valid initializer for a variable. newInitId = 0; @@ -437,7 +457,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(ir::Instruction* source, } } -size_t ScalarReplacementPass::GetIntegerLiteral(const ir::Operand& op) const { +size_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const { assert(op.words.size() <= 2); size_t len = 0; for (uint32_t i = 0; i != op.words.size(); ++i) { @@ -447,7 +467,7 @@ size_t ScalarReplacementPass::GetIntegerLiteral(const ir::Operand& op) const { } size_t ScalarReplacementPass::GetConstantInteger( - const ir::Instruction* constant) const { + const Instruction* constant) const { assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() == SpvOpTypeInt); assert(constant->opcode() == SpvOpConstant || @@ -456,23 +476,22 @@ size_t ScalarReplacementPass::GetConstantInteger( return 0; } - const ir::Operand& op = constant->GetInOperand(0u); + const Operand& op = constant->GetInOperand(0u); return GetIntegerLiteral(op); } size_t ScalarReplacementPass::GetArrayLength( - const ir::Instruction* arrayType) const { + const Instruction* arrayType) const { assert(arrayType->opcode() == SpvOpTypeArray); - const ir::Instruction* length = + const Instruction* length = get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); return GetConstantInteger(length); } -size_t ScalarReplacementPass::GetNumElements( - const ir::Instruction* type) const { +size_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { assert(type->opcode() == SpvOpTypeVector || type->opcode() == SpvOpTypeMatrix); - const ir::Operand& op = type->GetInOperand(1u); + const Operand& op = type->GetInOperand(1u); assert(op.words.size() <= 2); size_t len = 0; for (uint32_t i = 0; i != op.words.size(); ++i) { @@ -481,8 +500,8 @@ size_t ScalarReplacementPass::GetNumElements( return len; } -ir::Instruction* ScalarReplacementPass::GetStorageType( - const ir::Instruction* inst) const { +Instruction* ScalarReplacementPass::GetStorageType( + const Instruction* inst) const { assert(inst->opcode() == SpvOpVariable); uint32_t ptrTypeId = inst->type_id(); @@ -492,7 +511,7 @@ ir::Instruction* ScalarReplacementPass::GetStorageType( } bool ScalarReplacementPass::CanReplaceVariable( - const ir::Instruction* varInst) const { + const Instruction* varInst) const { assert(varInst->opcode() == SpvOpVariable); // Can only replace function scope variables. @@ -502,31 +521,31 @@ bool ScalarReplacementPass::CanReplaceVariable( if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) return false; - const ir::Instruction* typeInst = GetStorageType(varInst); + const Instruction* typeInst = GetStorageType(varInst); return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst); } -bool ScalarReplacementPass::CheckType(const ir::Instruction* typeInst) const { +bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { if (!CheckTypeAnnotations(typeInst)) return false; switch (typeInst->opcode()) { case SpvOpTypeStruct: // Don't bother with empty structs or very large structs. if (typeInst->NumInOperands() == 0 || - typeInst->NumInOperands() > MAX_NUM_ELEMENTS) + IsLargerThanSizeLimit(typeInst->NumInOperands())) return false; return true; case SpvOpTypeArray: - if (GetArrayLength(typeInst) > MAX_NUM_ELEMENTS) return false; + if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) return false; return true; - // TODO(alanbaker): Develop some heuristics for when this should be - // re-enabled. - //// Specifically including matrix and vector in an attempt to reduce the - //// number of vector registers required. - // case SpvOpTypeMatrix: - // case SpvOpTypeVector: - // if (GetNumElements(typeInst) > MAX_NUM_ELEMENTS) return false; - // return true; + // TODO(alanbaker): Develop some heuristics for when this should be + // re-enabled. + //// Specifically including matrix and vector in an attempt to reduce the + //// number of vector registers required. + // case SpvOpTypeMatrix: + // case SpvOpTypeVector: + // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; + // return true; case SpvOpTypeRuntimeArray: default: @@ -535,7 +554,7 @@ bool ScalarReplacementPass::CheckType(const ir::Instruction* typeInst) const { } bool ScalarReplacementPass::CheckTypeAnnotations( - const ir::Instruction* typeInst) const { + const Instruction* typeInst) const { for (auto inst : get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { uint32_t decoration; @@ -567,8 +586,7 @@ bool ScalarReplacementPass::CheckTypeAnnotations( return true; } -bool ScalarReplacementPass::CheckAnnotations( - const ir::Instruction* varInst) const { +bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { for (auto inst : get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { assert(inst->opcode() == SpvOpDecorate); @@ -588,7 +606,7 @@ bool ScalarReplacementPass::CheckAnnotations( return true; } -bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst) const { +bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { VariableStats stats = {0, 0}; bool ok = CheckUses(inst, &stats); @@ -599,20 +617,20 @@ bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst) const { return ok; } -bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst, +bool ScalarReplacementPass::CheckUses(const Instruction* inst, VariableStats* stats) const { bool ok = true; get_def_use_mgr()->ForEachUse( - inst, [this, stats, &ok](const ir::Instruction* user, uint32_t index) { + inst, [this, stats, &ok](const Instruction* user, uint32_t index) { // Annotations are check as a group separately. - if (!ir::IsAnnotationInst(user->opcode())) { + if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: if (index == 2u) { uint32_t id = user->GetSingleWordOperand(3u); - const ir::Instruction* opInst = get_def_use_mgr()->GetDef(id); - if (!ir::IsCompileTimeConstantInst(opInst->opcode())) { + const Instruction* opInst = get_def_use_mgr()->GetDef(id); + if (!IsCompileTimeConstantInst(opInst->opcode())) { ok = false; } else { if (!CheckUsesRelaxed(user)) ok = false; @@ -643,11 +661,10 @@ bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst, return ok; } -bool ScalarReplacementPass::CheckUsesRelaxed( - const ir::Instruction* inst) const { +bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { bool ok = true; get_def_use_mgr()->ForEachUse( - inst, [this, &ok](const ir::Instruction* user, uint32_t index) { + inst, [this, &ok](const Instruction* user, uint32_t index) { switch (user->opcode()) { case SpvOpAccessChain: case SpvOpInBoundsAccessChain: @@ -672,7 +689,7 @@ bool ScalarReplacementPass::CheckUsesRelaxed( return ok; } -bool ScalarReplacementPass::CheckLoad(const ir::Instruction* inst, +bool ScalarReplacementPass::CheckLoad(const Instruction* inst, uint32_t index) const { if (index != 2u) return false; if (inst->NumInOperands() >= 2 && @@ -681,7 +698,7 @@ bool ScalarReplacementPass::CheckLoad(const ir::Instruction* inst, return true; } -bool ScalarReplacementPass::CheckStore(const ir::Instruction* inst, +bool ScalarReplacementPass::CheckStore(const Instruction* inst, uint32_t index) const { if (index != 0u) return false; if (inst->NumInOperands() >= 3 && @@ -689,6 +706,101 @@ bool ScalarReplacementPass::CheckStore(const ir::Instruction* inst, return false; return true; } +bool ScalarReplacementPass::IsLargerThanSizeLimit(size_t length) const { + if (max_num_elements_ == 0) { + return false; + } + return length > max_num_elements_; +} + +std::unique_ptr> +ScalarReplacementPass::GetUsedComponents(Instruction* inst) { + std::unique_ptr> result( + new std::unordered_set()); + + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr, + this](Instruction* use) { + switch (use->opcode()) { + case SpvOpLoad: { + // Look for extract from the load. + std::vector t; + if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) { + if (use2->opcode() != SpvOpCompositeExtract) { + return false; + } + t.push_back(use2->GetSingleWordInOperand(1)); + return true; + })) { + result->insert(t.begin(), t.end()); + return true; + } else { + result.reset(nullptr); + return false; + } + } + case SpvOpStore: + // No components are used. Things are just stored to. + return true; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: { + // Add the first index it if is a constant. + // TODO: Could be improved by checking if the address is used in a load. + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + uint32_t index_id = use->GetSingleWordInOperand(1); + const analysis::Constant* index_const = + const_mgr->FindDeclaredConstant(index_id); + if (index_const) { + const analysis::Integer* index_type = + index_const->type()->AsInteger(); + assert(index_type); + if (index_type->width() == 32) { + result->insert(index_const->GetU32()); + return true; + } else if (index_type->width() == 64) { + result->insert(index_const->GetU64()); + return true; + } + result.reset(nullptr); + return false; + } else { + // Could be any element. Assuming all are used. + result.reset(nullptr); + return false; + } + } + case SpvOpCopyObject: { + // Follow the copy to see which components are used. + auto t = GetUsedComponents(use); + if (!t) { + result.reset(nullptr); + return false; + } + result->insert(t->begin(), t->end()); + return true; + } + default: + // We do not know what is happening. Have to assume the worst. + result.reset(nullptr); + return false; + } + }); + + return result; +} + +Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + const analysis::Type* type = type_mgr->GetType(type_id); + const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); + Instruction* null_inst = + const_mgr->GetDefiningInstruction(null_const, type_id); + context()->UpdateDefUse(null_inst); + return null_inst; +} } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.h b/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.h index a48174ff7..c89bbc401 100644 --- a/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.h +++ b/3rdparty/spirv-tools/source/opt/scalar_replacement_pass.h @@ -12,35 +12,47 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_SCALAR_REPLACEMENT_PASS_H_ -#define LIBSPIRV_OPT_SCALAR_REPLACEMENT_PASS_H_ - -#include "function.h" -#include "pass.h" -#include "type_manager.h" +#ifndef SOURCE_OPT_SCALAR_REPLACEMENT_PASS_H_ +#define SOURCE_OPT_SCALAR_REPLACEMENT_PASS_H_ +#include +#include #include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/pass.h" +#include "source/opt/type_manager.h" namespace spvtools { namespace opt { // Documented in optimizer.hpp class ScalarReplacementPass : public Pass { - public: - ScalarReplacementPass() = default; + private: + static const uint32_t kDefaultLimit = 100; - const char* name() const override { return "scalar-replacement"; } + public: + ScalarReplacementPass(uint32_t limit = kDefaultLimit) + : max_num_elements_(limit) { + name_[0] = '\0'; + strcat(name_, "scalar-replacement="); + sprintf(&name_[strlen(name_)], "%d", max_num_elements_); + } + + const char* name() const override { return name_; } // Attempts to scalarize all appropriate function scope variables. Returns // SuccessWithChange if any change is made. - Status Process(ir::IRContext* c) override; + Status Process() override; - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisDecorations | - ir::IRContext::kAnalysisCombinators | ir::IRContext::kAnalysisCFG | - ir::IRContext::kAnalysisNameMap; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisNameMap; } private: @@ -54,27 +66,27 @@ class ScalarReplacementPass : public Pass { // Attempts to scalarize all appropriate function scope variables in // |function|. Returns SuccessWithChange if any changes are mode. - Status ProcessFunction(ir::Function* function); + Status ProcessFunction(Function* function); // Returns true if |varInst| can be scalarized. // // Examines the use chain of |varInst| to verify all uses are valid for // scalarization. - bool CanReplaceVariable(const ir::Instruction* varInst) const; + bool CanReplaceVariable(const Instruction* varInst) const; // Returns true if |typeInst| is an acceptable type to scalarize. // // Allows all aggregate types except runtime arrays. Additionally, checks the // that the number of elements that would be scalarized is within bounds. - bool CheckType(const ir::Instruction* typeInst) const; + bool CheckType(const Instruction* typeInst) const; // Returns true if all the decorations for |varInst| are acceptable for // scalarization. - bool CheckAnnotations(const ir::Instruction* varInst) const; + bool CheckAnnotations(const Instruction* varInst) const; // Returns true if all the decorations for |typeInst| are acceptable for // scalarization. - bool CheckTypeAnnotations(const ir::Instruction* typeInst) const; + bool CheckTypeAnnotations(const Instruction* typeInst) const; // Returns true if the uses of |inst| are acceptable for scalarization. // @@ -83,20 +95,20 @@ class ScalarReplacementPass : public Pass { // SpvOpStore. Access chains must have the first index be a compile-time // constant. Subsequent uses of access chains (including other access chains) // are checked in a more relaxed manner. - bool CheckUses(const ir::Instruction* inst) const; + bool CheckUses(const Instruction* inst) const; // Helper function for the above |CheckUses|. // // This version tracks some stats about the current OpVariable. These stats // are used to drive heuristics about when to scalarize. - bool CheckUses(const ir::Instruction* inst, VariableStats* stats) const; + bool CheckUses(const Instruction* inst, VariableStats* stats) const; // Relaxed helper function for |CheckUses|. - bool CheckUsesRelaxed(const ir::Instruction* inst) const; + bool CheckUsesRelaxed(const Instruction* inst) const; // Transfers appropriate decorations from |source| to |replacements|. - void TransferAnnotations(const ir::Instruction* source, - std::vector* replacements); + void TransferAnnotations(const Instruction* source, + std::vector* replacements); // Scalarizes |inst| and updates its uses. // @@ -106,31 +118,30 @@ class ScalarReplacementPass : public Pass { // get added to |worklist| for further processing. If any replacement // variable ends up with no uses it is erased. Returns false if any // subsequent access chain is out of bounds. - bool ReplaceVariable(ir::Instruction* inst, - std::queue* worklist); + bool ReplaceVariable(Instruction* inst, std::queue* worklist); // Returns the underlying storage type for |inst|. // // |inst| must be an OpVariable. Returns the type that is pointed to by // |inst|. - ir::Instruction* GetStorageType(const ir::Instruction* inst) const; + Instruction* GetStorageType(const Instruction* inst) const; // Returns true if the load can be scalarized. // // |inst| must be an OpLoad. Returns true if |index| is the pointer operand of // |inst| and the load is not from volatile memory. - bool CheckLoad(const ir::Instruction* inst, uint32_t index) const; + bool CheckLoad(const Instruction* inst, uint32_t index) const; // Returns true if the store can be scalarized. // // |inst| must be an OpStore. Returns true if |index| is the pointer operand // of |inst| and the store is not to volatile memory. - bool CheckStore(const ir::Instruction* inst, uint32_t index) const; + bool CheckStore(const Instruction* inst, uint32_t index) const; // Creates a variable of type |typeId| from the |index|'th element of // |varInst|. The new variable is added to |replacements|. - void CreateVariable(uint32_t typeId, ir::Instruction* varInst, uint32_t index, - std::vector* replacements); + void CreateVariable(uint32_t typeId, Instruction* varInst, uint32_t index, + std::vector* replacements); // Populates |replacements| with a new OpVariable for each element of |inst|. // @@ -139,24 +150,24 @@ class ScalarReplacementPass : public Pass { // will contain a variable for each element of the composite with matching // indexes (i.e. the 0'th element of |inst| is the 0'th entry of // |replacements|). - void CreateReplacementVariables(ir::Instruction* inst, - std::vector* replacements); + void CreateReplacementVariables(Instruction* inst, + std::vector* replacements); // Returns the value of an OpConstant of integer type. // // |constant| must use two or fewer words to generate the value. - size_t GetConstantInteger(const ir::Instruction* constant) const; + size_t GetConstantInteger(const Instruction* constant) const; // Returns the integer literal for |op|. - size_t GetIntegerLiteral(const ir::Operand& op) const; + size_t GetIntegerLiteral(const Operand& op) const; // Returns the array length for |arrayInst|. - size_t GetArrayLength(const ir::Instruction* arrayInst) const; + size_t GetArrayLength(const Instruction* arrayInst) const; // Returns the number of elements in |type|. // // |type| must be a vector or matrix type. - size_t GetNumElements(const ir::Instruction* type) const; + size_t GetNumElements(const Instruction* type) const; // Returns an id for a pointer to |id|. uint32_t GetOrCreatePointerType(uint32_t id); @@ -166,8 +177,8 @@ class ScalarReplacementPass : public Pass { // If there is an initial value for |source| for element |index|, it is // appended as an operand on |newVar|. If the initial value is OpUndef, no // initial value is added to |newVar|. - void GetOrCreateInitialValue(ir::Instruction* source, uint32_t index, - ir::Instruction* newVar); + void GetOrCreateInitialValue(Instruction* source, uint32_t index, + Instruction* newVar); // Replaces the load to the entire composite. // @@ -175,31 +186,47 @@ class ScalarReplacementPass : public Pass { // composite by combining all of the loads. // // |load| must be a load. - void ReplaceWholeLoad(ir::Instruction* load, - const std::vector& replacements); + void ReplaceWholeLoad(Instruction* load, + const std::vector& replacements); // Replaces the store to the entire composite. // // Generates a composite extract and store for each element in the scalarized // variable from the original store data input. - void ReplaceWholeStore(ir::Instruction* store, - const std::vector& replacements); + void ReplaceWholeStore(Instruction* store, + const std::vector& replacements); // Replaces an access chain to the composite variable with either a direct use // of the appropriate replacement variable or another access chain with the // replacement variable as the base and one fewer indexes. Returns false if // the chain has an out of bounds access. - bool ReplaceAccessChain(ir::Instruction* chain, - const std::vector& replacements); + bool ReplaceAccessChain(Instruction* chain, + const std::vector& replacements); + + // Returns a set containing the which components of the result of |inst| are + // potentially used. If the return value is |nullptr|, then every components + // is possibly used. + std::unique_ptr> GetUsedComponents( + Instruction* inst); + + // Returns an instruction defining a null constant with type |type_id|. If + // one already exists, it is returned. Otherwise a new one is created. + Instruction* CreateNullConstant(uint32_t type_id); // Maps storage type to a pointer type enclosing that type. std::unordered_map pointee_to_pointer_; // Maps type id to OpConstantNull for that type. std::unordered_map type_to_null_; + + // Limit on the number of members in an object that will be replaced. + // 0 means there is no limit. + uint32_t max_num_elements_; + bool IsLargerThanSizeLimit(size_t length) const; + char name_[55]; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_SCALAR_REPLACEMENT_PASS_H_ +#endif // SOURCE_OPT_SCALAR_REPLACEMENT_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.cpp b/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.cpp index bce78f9c6..4c8d116f7 100644 --- a/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "set_spec_constant_default_value_pass.h" +#include "source/opt/set_spec_constant_default_value_pass.h" #include #include @@ -20,22 +20,22 @@ #include #include -#include "def_use_manager.h" -#include "ir_context.h" -#include "make_unique.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/type_manager.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" +#include "source/util/parse_number.h" #include "spirv-tools/libspirv.h" -#include "type_manager.h" -#include "types.h" -#include "util/parse_number.h" namespace spvtools { namespace opt { namespace { -using spvutils::EncodeNumberStatus; -using spvutils::NumberType; -using spvutils::ParseAndEncodeNumber; -using spvutils::ParseNumber; +using utils::EncodeNumberStatus; +using utils::NumberType; +using utils::ParseAndEncodeNumber; +using utils::ParseNumber; // Given a numeric value in a null-terminated c string and the expected type of // the value, parses the string and encodes it in a vector of words. If the @@ -112,7 +112,7 @@ std::vector ParseDefaultValueBitPattern( // Returns true if the given instruction's result id could have a SpecId // decoration. -bool CanHaveSpecIdDecoration(const ir::Instruction& inst) { +bool CanHaveSpecIdDecoration(const Instruction& inst) { switch (inst.opcode()) { case SpvOp::SpvOpSpecConstant: case SpvOp::SpvOpSpecConstantFalse: @@ -127,8 +127,8 @@ bool CanHaveSpecIdDecoration(const ir::Instruction& inst) { // decoration, finds the spec constant defining instruction which is the real // target of the SpecId decoration. Returns the spec constant defining // instruction if such an instruction is found, otherwise returns a nullptr. -ir::Instruction* GetSpecIdTargetFromDecorationGroup( - const ir::Instruction& decoration_group_defining_inst, +Instruction* GetSpecIdTargetFromDecorationGroup( + const Instruction& decoration_group_defining_inst, analysis::DefUseManager* def_use_mgr) { // Find the OpGroupDecorate instruction which consumes the given decoration // group. Note that the given decoration group has SpecId decoration, which @@ -136,9 +136,9 @@ ir::Instruction* GetSpecIdTargetFromDecorationGroup( // consumed by different OpGroupDecorate instructions. Therefore we only need // the first OpGroupDecoration instruction that uses the given decoration // group. - ir::Instruction* group_decorate_inst = nullptr; + Instruction* group_decorate_inst = nullptr; if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst, - [&group_decorate_inst](ir::Instruction* user) { + [&group_decorate_inst](Instruction* user) { if (user->opcode() == SpvOp::SpvOpGroupDecorate) { group_decorate_inst = user; @@ -155,12 +155,12 @@ ir::Instruction* GetSpecIdTargetFromDecorationGroup( // instruction. If the OpGroupDecorate instruction has different target ids // or a target id is not defined by an eligible spec cosntant instruction, // returns a nullptr. - ir::Instruction* target_inst = nullptr; + Instruction* target_inst = nullptr; for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) { // All the operands of a OpGroupDecorate instruction should be of type // SPV_OPERAND_TYPE_ID. uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i); - ir::Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id); + Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id); if (!candidate_inst) { continue; @@ -190,10 +190,7 @@ ir::Instruction* GetSpecIdTargetFromDecorationGroup( } } // namespace -Pass::Status SetSpecConstantDefaultValuePass::Process( - ir::IRContext* irContext) { - InitializeProcessing(irContext); - +Pass::Status SetSpecConstantDefaultValuePass::Process() { // The operand index of decoration target in an OpDecorate instruction. const uint32_t kTargetIdOperandIndex = 0; // The operand index of the decoration literal in an OpDecorate instruction. @@ -216,7 +213,7 @@ Pass::Status SetSpecConstantDefaultValuePass::Process( // is found for a spec id, the string will be parsed according to the target // spec constant type. The parsed value will be used to replace the original // default value of the target spec constant. - for (ir::Instruction& inst : irContext->annotations()) { + for (Instruction& inst : context()->annotations()) { // Only process 'OpDecorate SpecId' instructions if (inst.opcode() != SpvOp::SpvOpDecorate) continue; if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue; @@ -231,8 +228,8 @@ Pass::Status SetSpecConstantDefaultValuePass::Process( // Find the spec constant defining instruction. Note that the // target_id might be a decoration group id. - ir::Instruction* spec_inst = nullptr; - if (ir::Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) { + Instruction* spec_inst = nullptr; + if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) { if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) { spec_inst = GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr()); diff --git a/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.h b/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.h index 95667bb89..8bd1787cc 100644 --- a/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.h +++ b/3rdparty/spirv-tools/source/opt/set_spec_constant_default_value_pass.h @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ -#define LIBSPIRV_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ +#ifndef SOURCE_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ +#define SOURCE_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ #include #include #include +#include +#include -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -32,7 +34,7 @@ class SetSpecConstantDefaultValuePass : public Pass { using SpecIdToValueStrMap = std::unordered_map; using SpecIdToValueBitPatternMap = std::unordered_map>; - using SpecIdToInstMap = std::unordered_map; + using SpecIdToInstMap = std::unordered_map; // Constructs a pass instance with a map from spec ids to default values // in the form of string. @@ -56,7 +58,7 @@ class SetSpecConstantDefaultValuePass : public Pass { spec_id_to_value_bit_pattern_(std::move(default_values)) {} const char* name() const override { return "set-spec-const-default-value"; } - Status Process(ir::IRContext*) override; + Status Process() override; // Parses the given null-terminated C string to get a mapping from Spec Id to // default value strings. Returns a unique pointer of the mapping from spec @@ -109,4 +111,4 @@ class SetSpecConstantDefaultValuePass : public Pass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ +#endif // SOURCE_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/simplification_pass.cpp b/3rdparty/spirv-tools/source/opt/simplification_pass.cpp index 356ab90be..5fbafbdd1 100644 --- a/3rdparty/spirv-tools/source/opt/simplification_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/simplification_pass.cpp @@ -12,28 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "simplification_pass.h" +#include "source/opt/simplification_pass.h" #include #include #include -#include "fold.h" +#include "source/opt/fold.h" namespace spvtools { namespace opt { -Pass::Status SimplificationPass::Process(ir::IRContext* c) { - InitializeProcessing(c); +Pass::Status SimplificationPass::Process() { bool modified = false; - for (ir::Function& function : *get_module()) { + for (Function& function : *get_module()) { modified |= SimplifyFunction(&function); } return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } -bool SimplificationPass::SimplifyFunction(ir::Function* function) { +bool SimplificationPass::SimplifyFunction(Function* function) { bool modified = false; // Phase 1: Traverse all instructions in dominance order. // The second phase will only be on the instructions whose inputs have changed @@ -41,27 +40,28 @@ bool SimplificationPass::SimplifyFunction(ir::Function* function) { // only instructions whose inputs do not necessarily dominate the use, we keep // track of the OpPhi instructions already seen, and add them to the work list // for phase 2 when needed. - std::vector work_list; - std::unordered_set process_phis; - std::unordered_set inst_to_kill; - std::unordered_set in_work_list; + std::vector work_list; + std::unordered_set process_phis; + std::unordered_set inst_to_kill; + std::unordered_set in_work_list; + const InstructionFolder& folder = context()->get_instruction_folder(); cfg()->ForEachBlockInReversePostOrder( function->entry().get(), [&modified, &process_phis, &work_list, &in_work_list, &inst_to_kill, - this](ir::BasicBlock* bb) { - for (ir::Instruction* inst = &*bb->begin(); inst; - inst = inst->NextNode()) { + folder, this](BasicBlock* bb) { + for (Instruction* inst = &*bb->begin(); inst; inst = inst->NextNode()) { if (inst->opcode() == SpvOpPhi) { process_phis.insert(inst); } - if (inst->opcode() == SpvOpCopyObject || FoldInstruction(inst)) { + if (inst->opcode() == SpvOpCopyObject || + folder.FoldInstruction(inst)) { modified = true; context()->AnalyzeUses(inst); get_def_use_mgr()->ForEachUser(inst, [&work_list, &process_phis, &in_work_list]( - ir::Instruction* use) { + Instruction* use) { if (process_phis.count(use) && in_work_list.insert(use).second) { work_list.push_back(use); } @@ -71,6 +71,9 @@ bool SimplificationPass::SimplifyFunction(ir::Function* function) { inst->GetSingleWordInOperand(0)); inst_to_kill.insert(inst); in_work_list.insert(inst); + } else if (inst->opcode() == SpvOpNop) { + inst_to_kill.insert(inst); + in_work_list.insert(inst); } } } @@ -80,13 +83,13 @@ bool SimplificationPass::SimplifyFunction(ir::Function* function) { // done. This time we add all users to the work list because phase 1 // has already finished. for (size_t i = 0; i < work_list.size(); ++i) { - ir::Instruction* inst = work_list[i]; + Instruction* inst = work_list[i]; in_work_list.erase(inst); - if (inst->opcode() == SpvOpCopyObject || FoldInstruction(inst)) { + if (inst->opcode() == SpvOpCopyObject || folder.FoldInstruction(inst)) { modified = true; context()->AnalyzeUses(inst); get_def_use_mgr()->ForEachUser( - inst, [&work_list, &in_work_list](ir::Instruction* use) { + inst, [&work_list, &in_work_list](Instruction* use) { if (!use->IsDecoration() && use->opcode() != SpvOpName && in_work_list.insert(use).second) { work_list.push_back(use); @@ -98,12 +101,15 @@ bool SimplificationPass::SimplifyFunction(ir::Function* function) { inst->GetSingleWordInOperand(0)); inst_to_kill.insert(inst); in_work_list.insert(inst); + } else if (inst->opcode() == SpvOpNop) { + inst_to_kill.insert(inst); + in_work_list.insert(inst); } } } // Phase 3: Kill instructions we know are no longer needed. - for (ir::Instruction* inst : inst_to_kill) { + for (Instruction* inst : inst_to_kill) { context()->KillInst(inst); } diff --git a/3rdparty/spirv-tools/source/opt/simplification_pass.h b/3rdparty/spirv-tools/source/opt/simplification_pass.h index 206d9dc8e..348c96a03 100644 --- a/3rdparty/spirv-tools/source/opt/simplification_pass.h +++ b/3rdparty/spirv-tools/source/opt/simplification_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_SIMPLIFICATION_PASS_H_ -#define LIBSPIRV_OPT_SIMPLIFICATION_PASS_H_ +#ifndef SOURCE_OPT_SIMPLIFICATION_PASS_H_ +#define SOURCE_OPT_SIMPLIFICATION_PASS_H_ -#include "function.h" -#include "ir_context.h" -#include "pass.h" +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,24 +26,24 @@ namespace opt { class SimplificationPass : public Pass { public: const char* name() const override { return "simplify-instructions"; } - Status Process(ir::IRContext*) override; - virtual ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisDecorations | - ir::IRContext::kAnalysisCombinators | ir::IRContext::kAnalysisCFG | - ir::IRContext::kAnalysisDominatorAnalysis | - ir::IRContext::kAnalysisNameMap; + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap; } private: // Returns true if the module was changed. The simplifier is called on every // instruction in |function| until nothing else in the function can be // simplified. - bool SimplifyFunction(ir::Function* function); + bool SimplifyFunction(Function* function); }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_SIMPLIFICATION_PASS_H_ +#endif // SOURCE_OPT_SIMPLIFICATION_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.cpp b/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.cpp index 1e8be8836..83d243311 100644 --- a/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.cpp @@ -39,14 +39,16 @@ // some Phi instructions may be dead // (https://en.wikipedia.org/wiki/Static_single_assignment_form). -#include "ssa_rewrite_pass.h" -#include "cfg.h" -#include "make_unique.h" -#include "mem_pass.h" -#include "opcode.h" +#include "source/opt/ssa_rewrite_pass.h" +#include #include +#include "source/opcode.h" +#include "source/opt/cfg.h" +#include "source/opt/mem_pass.h" +#include "source/util/make_unique.h" + // Debug logging (0: Off, 1-N: Verbosity level). Replace this with the // implementation done for // https://github.com/KhronosGroup/SPIRV-Tools/issues/1351 @@ -66,7 +68,7 @@ const uint32_t kStoreValIdInIdx = 1; const uint32_t kVariableInitIdInIdx = 1; } // namespace -std::string SSARewriter::PhiCandidate::PrettyPrint(const ir::CFG* cfg) const { +std::string SSARewriter::PhiCandidate::PrettyPrint(const CFG* cfg) const { std::ostringstream str; str << "%" << result_id_ << " = Phi[%" << var_id_ << ", BB %" << bb_->id() << "]("; @@ -87,7 +89,7 @@ std::string SSARewriter::PhiCandidate::PrettyPrint(const ir::CFG* cfg) const { } SSARewriter::PhiCandidate& SSARewriter::CreatePhiCandidate(uint32_t var_id, - ir::BasicBlock* bb) { + BasicBlock* bb) { uint32_t phi_result_id = pass_->context()->TakeNextId(); auto result = phi_candidates_.emplace( phi_result_id, PhiCandidate(var_id, phi_result_id, bb)); @@ -161,7 +163,7 @@ uint32_t SSARewriter::AddPhiOperands(PhiCandidate* phi_candidate) { bool found_0_arg = false; for (uint32_t pred : pass_->cfg()->preds(phi_candidate->bb()->id())) { - ir::BasicBlock* pred_bb = pass_->cfg()->block(pred); + BasicBlock* pred_bb = pass_->cfg()->block(pred); // If |pred_bb| is not sealed, use %0 to indicate that // |phi_candidate| needs to be completed after the whole CFG has @@ -233,7 +235,7 @@ uint32_t SSARewriter::AddPhiOperands(PhiCandidate* phi_candidate) { return repl_id; } -uint32_t SSARewriter::GetReachingDef(uint32_t var_id, ir::BasicBlock* bb) { +uint32_t SSARewriter::GetReachingDef(uint32_t var_id, BasicBlock* bb) { // If |var_id| has a definition in |bb|, return it. const auto& bb_it = defs_at_block_.find(bb); if (bb_it != defs_at_block_.end()) { @@ -271,14 +273,14 @@ uint32_t SSARewriter::GetReachingDef(uint32_t var_id, ir::BasicBlock* bb) { return val_id; } -void SSARewriter::SealBlock(ir::BasicBlock* bb) { +void SSARewriter::SealBlock(BasicBlock* bb) { auto result = sealed_blocks_.insert(bb); (void)result; assert(result.second == true && "Tried to seal the same basic block more than once."); } -void SSARewriter::ProcessStore(ir::Instruction* inst, ir::BasicBlock* bb) { +void SSARewriter::ProcessStore(Instruction* inst, BasicBlock* bb) { auto opcode = inst->opcode(); assert((opcode == SpvOpStore || opcode == SpvOpVariable) && "Expecting a store or a variable definition instruction."); @@ -303,7 +305,7 @@ void SSARewriter::ProcessStore(ir::Instruction* inst, ir::BasicBlock* bb) { } } -void SSARewriter::ProcessLoad(ir::Instruction* inst, ir::BasicBlock* bb) { +void SSARewriter::ProcessLoad(Instruction* inst, BasicBlock* bb) { uint32_t var_id = 0; (void)pass_->GetPtr(inst, &var_id); if (pass_->IsTargetVar(var_id)) { @@ -346,7 +348,7 @@ void SSARewriter::PrintReplacementTable() const { std::cerr << "\n"; } -void SSARewriter::GenerateSSAReplacements(ir::BasicBlock* bb) { +void SSARewriter::GenerateSSAReplacements(BasicBlock* bb) { #if SSA_REWRITE_DEBUGGING_LEVEL > 1 std::cerr << "Generating SSA replacements for block: " << bb->id() << "\n"; std::cerr << bb->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) @@ -416,7 +418,7 @@ bool SSARewriter::ApplyReplacements() { #endif // Add Phi instructions from completed Phi candidates. - std::vector generated_phis; + std::vector generated_phis; for (const PhiCandidate* phi_candidate : phis_to_generate_) { #if SSA_REWRITE_DEBUGGING_LEVEL > 2 std::cerr << "Phi candidate: " << phi_candidate->PrettyPrint(pass_->cfg()) @@ -430,7 +432,7 @@ bool SSARewriter::ApplyReplacements() { // Build the vector of operands for the new OpPhi instruction. uint32_t type_id = pass_->GetPointeeTypeId( pass_->get_def_use_mgr()->GetDef(phi_candidate->var_id())); - std::vector phi_operands; + std::vector phi_operands; uint32_t arg_ix = 0; for (uint32_t pred_label : pass_->cfg()->preds(phi_candidate->bb()->id())) { uint32_t op_val_id = GetPhiArgument(phi_candidate, arg_ix++); @@ -442,21 +444,26 @@ bool SSARewriter::ApplyReplacements() { // Generate a new OpPhi instruction and insert it in its basic // block. - std::unique_ptr phi_inst( - new ir::Instruction(pass_->context(), SpvOpPhi, type_id, - phi_candidate->result_id(), phi_operands)); + std::unique_ptr phi_inst( + new Instruction(pass_->context(), SpvOpPhi, type_id, + phi_candidate->result_id(), phi_operands)); generated_phis.push_back(phi_inst.get()); pass_->get_def_use_mgr()->AnalyzeInstDef(&*phi_inst); pass_->context()->set_instr_block(&*phi_inst, phi_candidate->bb()); auto insert_it = phi_candidate->bb()->begin(); insert_it.InsertBefore(std::move(phi_inst)); + + pass_->context()->get_decoration_mgr()->CloneDecorations( + phi_candidate->var_id(), phi_candidate->result_id(), + {SpvDecorationRelaxedPrecision}); + modified = true; } // Scan uses for all inserted Phi instructions. Do this separately from the // registration of the Phi instruction itself to avoid trying to analyze uses // of Phi instructions that have not been registered yet. - for (ir::Instruction* phi_inst : generated_phis) { + for (Instruction* phi_inst : generated_phis) { pass_->get_def_use_mgr()->AnalyzeInstUse(&*phi_inst); } @@ -469,7 +476,7 @@ bool SSARewriter::ApplyReplacements() { for (auto& repl : load_replacement_) { uint32_t load_id = repl.first; uint32_t val_id = GetReplacement(repl); - ir::Instruction* load_inst = + Instruction* load_inst = pass_->context()->get_def_use_mgr()->GetDef(load_id); #if SSA_REWRITE_DEBUGGING_LEVEL > 2 @@ -498,7 +505,7 @@ void SSARewriter::FinalizePhiCandidate(PhiCandidate* phi_candidate) { uint32_t ix = 0; for (uint32_t pred : pass_->cfg()->preds(phi_candidate->bb()->id())) { - ir::BasicBlock* pred_bb = pass_->cfg()->block(pred); + BasicBlock* pred_bb = pass_->cfg()->block(pred); uint32_t& arg_id = phi_candidate->phi_args()[ix++]; if (arg_id == 0) { // If |pred_bb| is still not sealed, it means it's unreachable. In this @@ -536,7 +543,7 @@ void SSARewriter::FinalizePhiCandidates() { } } -bool SSARewriter::RewriteFunctionIntoSSA(ir::Function* fp) { +bool SSARewriter::RewriteFunctionIntoSSA(Function* fp) { #if SSA_REWRITE_DEBUGGING_LEVEL > 0 std::cerr << "Function before SSA rewrite:\n" << fp->PrettyPrint(0) << "\n\n\n"; @@ -549,7 +556,7 @@ bool SSARewriter::RewriteFunctionIntoSSA(ir::Function* fp) { // generate incomplete and trivial Phis. pass_->cfg()->ForEachBlockInReversePostOrder( fp->entry().get(), - [this](ir::BasicBlock* bb) { GenerateSSAReplacements(bb); }); + [this](BasicBlock* bb) { GenerateSSAReplacements(bb); }); // Remove trivial Phis and add arguments to incomplete Phis. FinalizePhiCandidates(); @@ -565,11 +572,7 @@ bool SSARewriter::RewriteFunctionIntoSSA(ir::Function* fp) { return modified; } -void SSARewritePass::Initialize(ir::IRContext* c) { InitializeProcessing(c); } - -Pass::Status SSARewritePass::Process(ir::IRContext* c) { - Initialize(c); - +Pass::Status SSARewritePass::Process() { bool modified = false; for (auto& fn : *get_module()) { modified |= SSARewriter(this).RewriteFunctionIntoSSA(&fn); diff --git a/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.h b/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.h index e58943635..c0373dc06 100644 --- a/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.h +++ b/3rdparty/spirv-tools/source/opt/ssa_rewrite_pass.h @@ -12,14 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_SSA_REWRITE_PASS_H_ -#define LIBSPIRV_OPT_SSA_REWRITE_PASS_H_ - -#include "basic_block.h" -#include "ir_context.h" -#include "mem_pass.h" +#ifndef SOURCE_OPT_SSA_REWRITE_PASS_H_ +#define SOURCE_OPT_SSA_REWRITE_PASS_H_ +#include +#include #include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" namespace spvtools { namespace opt { @@ -43,12 +48,12 @@ class SSARewriter { // // It returns true if function |fp| was modified. Otherwise, it returns // false. - bool RewriteFunctionIntoSSA(ir::Function* fp); + bool RewriteFunctionIntoSSA(Function* fp); private: class PhiCandidate { public: - explicit PhiCandidate(uint32_t var, uint32_t result, ir::BasicBlock* block) + explicit PhiCandidate(uint32_t var, uint32_t result, BasicBlock* block) : var_id_(var), result_id_(result), bb_(block), @@ -59,7 +64,7 @@ class SSARewriter { uint32_t var_id() const { return var_id_; } uint32_t result_id() const { return result_id_; } - ir::BasicBlock* bb() const { return bb_; } + BasicBlock* bb() const { return bb_; } std::vector& phi_args() { return phi_args_; } const std::vector& phi_args() const { return phi_args_; } uint32_t copy_of() const { return copy_of_; } @@ -81,7 +86,7 @@ class SSARewriter { // Pretty prints this Phi candidate into a string and returns it. |cfg| is // needed to lookup basic block predecessors. - std::string PrettyPrint(const ir::CFG* cfg) const; + std::string PrettyPrint(const CFG* cfg) const; // Registers |operand_id| as a user of this Phi candidate. void AddUser(uint32_t operand_id) { users_.push_back(operand_id); } @@ -95,7 +100,7 @@ class SSARewriter { uint32_t result_id_; // Basic block to hold this Phi. - ir::BasicBlock* bb_; + BasicBlock* bb_; // Vector of operands for every predecessor block of |bb|. This vector is // organized so that the Ith slot contains the argument coming from the Ith @@ -117,23 +122,21 @@ class SSARewriter { }; // Type used to keep track of store operations in each basic block. - typedef std::unordered_map> BlockDefsMap; // Generates all the SSA rewriting decisions for basic block |bb|. This // populates the Phi candidate table (|phi_candidate_|) and the load // replacement table (|load_replacement_). - void GenerateSSAReplacements(ir::BasicBlock* bb); + void GenerateSSAReplacements(BasicBlock* bb); // Seals block |bb|. Sealing a basic block means |bb| and all its // predecessors of |bb| have been scanned for loads/stores. - void SealBlock(ir::BasicBlock* bb); + void SealBlock(BasicBlock* bb); // Returns true if |bb| has been sealed. - bool IsBlockSealed(ir::BasicBlock* bb) { - return sealed_blocks_.count(bb) != 0; - } + bool IsBlockSealed(BasicBlock* bb) { return sealed_blocks_.count(bb) != 0; } // Returns the Phi candidate with result ID |id| if it exists in the table // |phi_candidates_|. If no such Phi candidate exists, it returns nullptr. @@ -183,7 +186,7 @@ class SSARewriter { // Registers a definition for variable |var_id| in basic block |bb| with // value |val_id|. - void WriteVariable(uint32_t var_id, ir::BasicBlock* bb, uint32_t val_id) { + void WriteVariable(uint32_t var_id, BasicBlock* bb, uint32_t val_id) { defs_at_block_[bb][var_id] = val_id; } @@ -191,13 +194,13 @@ class SSARewriter { // the variable ID being stored into, determines whether the variable is an // SSA-target variable, and, if it is, it stores its value in the // |defs_at_block_| map. - void ProcessStore(ir::Instruction* inst, ir::BasicBlock* bb); + void ProcessStore(Instruction* inst, BasicBlock* bb); // Processes the load operation |inst| in basic block |bb|. This extracts // the variable ID being stored into, determines whether the variable is an // SSA-target variable, and, if it is, it reads its reaching definition by // calling |GetReachingDef|. - void ProcessLoad(ir::Instruction* inst, ir::BasicBlock* bb); + void ProcessLoad(Instruction* inst, BasicBlock* bb); // Reads the current definition for variable |var_id| in basic block |bb|. // If |var_id| is not defined in block |bb| it walks up the predecessors of @@ -205,7 +208,7 @@ class SSARewriter { // // It returns the value for |var_id| from the RHS of the current reaching // definition for |var_id|. - uint32_t GetReachingDef(uint32_t var_id, ir::BasicBlock* bb); + uint32_t GetReachingDef(uint32_t var_id, BasicBlock* bb); // Adds arguments to |phi_candidate| by getting the reaching definition of // |phi_candidate|'s variable on each of the predecessors of its basic @@ -223,7 +226,7 @@ class SSARewriter { // during rewriting. // // Once the candidate Phi is created, it returns its ID. - PhiCandidate& CreatePhiCandidate(uint32_t var_id, ir::BasicBlock* bb); + PhiCandidate& CreatePhiCandidate(uint32_t var_id, BasicBlock* bb); // Attempts to remove a trivial Phi candidate |phi_cand|. Trivial Phis are // those that only reference themselves and one other value |val| any number @@ -277,7 +280,7 @@ class SSARewriter { std::unordered_map load_replacement_; // Set of blocks that have been sealed already. - std::unordered_set sealed_blocks_; + std::unordered_set sealed_blocks_; // Memory pass requesting the SSA rewriter. MemPass* pass_; @@ -290,15 +293,12 @@ class SSARewriter { class SSARewritePass : public MemPass { public: SSARewritePass() = default; - const char* name() const override { return "ssa-rewrite"; } - Status Process(ir::IRContext* c) override; - private: - // Initializes the pass. - void Initialize(ir::IRContext* c); + const char* name() const override { return "ssa-rewrite"; } + Status Process() override; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_SSA_REWRITE_PASS_H_ +#endif // SOURCE_OPT_SSA_REWRITE_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/strength_reduction_pass.cpp b/3rdparty/spirv-tools/source/opt/strength_reduction_pass.cpp index fd8ccf97c..ab7c4eb8d 100644 --- a/3rdparty/spirv-tools/source/opt/strength_reduction_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/strength_reduction_pass.cpp @@ -12,18 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "strength_reduction_pass.h" +#include "source/opt/strength_reduction_pass.h" #include #include #include +#include #include #include +#include +#include -#include "def_use_manager.h" -#include "ir_context.h" -#include "log.h" -#include "reflect.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" namespace { // Count the number of trailing zeros in the binary representation of @@ -53,9 +56,7 @@ bool IsPowerOf2(uint32_t val) { namespace spvtools { namespace opt { -Pass::Status StrengthReductionPass::Process(ir::IRContext* c) { - InitializeProcessing(c); - +Pass::Status StrengthReductionPass::Process() { // Initialize the member variables on a per module basis. bool modified = false; int32_type_id_ = 0; @@ -68,7 +69,7 @@ Pass::Status StrengthReductionPass::Process(ir::IRContext* c) { } bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( - ir::BasicBlock::iterator* inst) { + BasicBlock::iterator* inst) { assert((*inst)->opcode() == SpvOp::SpvOpIMul && "Only works for multiplication of integers."); bool modified = false; @@ -82,7 +83,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( // Check the operands for a constant that is a power of 2. for (int i = 0; i < 2; i++) { uint32_t opId = (*inst)->GetSingleWordInOperand(i); - ir::Instruction* opInst = get_def_use_mgr()->GetDef(opId); + Instruction* opInst = get_def_use_mgr()->GetDef(opId); if (opInst->opcode() == SpvOp::SpvOpConstant) { // We found a constant operand. uint32_t constVal = opInst->GetSingleWordOperand(2); @@ -94,14 +95,14 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( // Create the new instruction. uint32_t newResultId = TakeNextId(); - std::vector newOperands; + std::vector newOperands; newOperands.push_back((*inst)->GetInOperand(1 - i)); - ir::Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, - {shiftConstResultId}); + Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {shiftConstResultId}); newOperands.push_back(shiftOperand); - std::unique_ptr newInstruction( - new ir::Instruction(context(), SpvOp::SpvOpShiftLeftLogical, - (*inst)->type_id(), newResultId, newOperands)); + std::unique_ptr newInstruction( + new Instruction(context(), SpvOp::SpvOpShiftLeftLogical, + (*inst)->type_id(), newResultId, newOperands)); // Insert the new instruction and update the data structures. (*inst) = (*inst).InsertBefore(std::move(newInstruction)); @@ -110,7 +111,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId); // Remove the old instruction. - ir::Instruction* inst_to_delete = &*(*inst); + Instruction* inst_to_delete = &*(*inst); --(*inst); context()->KillInst(inst_to_delete); @@ -156,11 +157,11 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) { // Construct the constant. uint32_t resultId = TakeNextId(); - ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, - {val}); - std::unique_ptr newConstant( - new ir::Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_, - resultId, {constant})); + Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {val}); + std::unique_ptr newConstant( + new Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_, + resultId, {constant})); get_module()->AddGlobalValue(std::move(newConstant)); // Notify the DefUseManager about this constant. diff --git a/3rdparty/spirv-tools/source/opt/strength_reduction_pass.h b/3rdparty/spirv-tools/source/opt/strength_reduction_pass.h index 6c233e151..8dfeb307b 100644 --- a/3rdparty/spirv-tools/source/opt/strength_reduction_pass.h +++ b/3rdparty/spirv-tools/source/opt/strength_reduction_pass.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_ -#define LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_ +#ifndef SOURCE_OPT_STRENGTH_REDUCTION_PASS_H_ +#define SOURCE_OPT_STRENGTH_REDUCTION_PASS_H_ -#include "def_use_manager.h" -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -27,12 +27,12 @@ namespace opt { class StrengthReductionPass : public Pass { public: const char* name() const override { return "strength-reduction"; } - Status Process(ir::IRContext*) override; + Status Process() override; private: // Replaces multiple by power of 2 with an equivalent bit shift. // Returns true if something changed. - bool ReplaceMultiplyByPowerOf2(ir::BasicBlock::iterator*); + bool ReplaceMultiplyByPowerOf2(BasicBlock::iterator*); // Scan the types and constants in the module looking for the the integer // types that we are @@ -62,4 +62,4 @@ class StrengthReductionPass : public Pass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_ +#endif // SOURCE_OPT_STRENGTH_REDUCTION_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.cpp b/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.cpp index ae35b1096..5d9c5fec8 100644 --- a/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.cpp @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "strip_debug_info_pass.h" -#include "ir_context.h" +#include "source/opt/strip_debug_info_pass.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { -Pass::Status StripDebugInfoPass::Process(ir::IRContext* irContext) { - bool modified = !irContext->debugs1().empty() || - !irContext->debugs2().empty() || - !irContext->debugs3().empty(); - irContext->debug_clear(); +Pass::Status StripDebugInfoPass::Process() { + bool modified = !context()->debugs1().empty() || + !context()->debugs2().empty() || + !context()->debugs3().empty(); + context()->debug_clear(); - irContext->module()->ForEachInst([&modified](ir::Instruction* inst) { + context()->module()->ForEachInst([&modified](Instruction* inst) { modified |= !inst->dbg_line_insts().empty(); inst->dbg_line_insts().clear(); }); diff --git a/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.h b/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.h index 52cbd680a..47a2cd409 100644 --- a/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.h +++ b/3rdparty/spirv-tools/source/opt/strip_debug_info_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_STRIP_DEBUG_INFO_PASS_H_ -#define LIBSPIRV_OPT_STRIP_DEBUG_INFO_PASS_H_ +#ifndef SOURCE_OPT_STRIP_DEBUG_INFO_PASS_H_ +#define SOURCE_OPT_STRIP_DEBUG_INFO_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,10 +26,10 @@ namespace opt { class StripDebugInfoPass : public Pass { public: const char* name() const override { return "strip-debug"; } - Status Process(ir::IRContext* irContext) override; + Status Process() override; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_STRIP_DEBUG_INFO_PASS_H_ +#endif // SOURCE_OPT_STRIP_DEBUG_INFO_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.cpp b/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.cpp index d863e6600..14ce31ff3 100644 --- a/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.cpp @@ -12,25 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "strip_reflect_info_pass.h" +#include "source/opt/strip_reflect_info_pass.h" #include +#include -#include "instruction.h" -#include "ir_context.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { -using spvtools::ir::Instruction; - -Pass::Status StripReflectInfoPass::Process(ir::IRContext* irContext) { +Pass::Status StripReflectInfoPass::Process() { bool modified = false; std::vector to_remove; bool other_uses_for_decorate_string = false; - for (auto& inst : irContext->module()->annotations()) { + for (auto& inst : context()->module()->annotations()) { switch (inst.opcode()) { case SpvOpDecorateStringGOOGLE: if (inst.GetSingleWordInOperand(1) == SpvDecorationHlslSemanticGOOGLE) { @@ -52,7 +51,7 @@ Pass::Status StripReflectInfoPass::Process(ir::IRContext* irContext) { } } - for (auto& inst : irContext->module()->extensions()) { + for (auto& inst : context()->module()->extensions()) { const char* ext_name = reinterpret_cast(&inst.GetInOperand(0).words[0]); if (0 == std::strcmp(ext_name, "SPV_GOOGLE_hlsl_functionality1")) { @@ -65,7 +64,7 @@ Pass::Status StripReflectInfoPass::Process(ir::IRContext* irContext) { for (auto* inst : to_remove) { modified = true; - irContext->KillInst(inst); + context()->KillInst(inst); } return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; diff --git a/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.h b/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.h index b6e9f33fb..935a605e3 100644 --- a/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.h +++ b/3rdparty/spirv-tools/source/opt/strip_reflect_info_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_STRIP_REFLECT_INFO_PASS_H_ -#define LIBSPIRV_OPT_STRIP_REFLECT_INFO_PASS_H_ +#ifndef SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ +#define SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,19 +26,18 @@ namespace opt { class StripReflectInfoPass : public Pass { public: const char* name() const override { return "strip-reflect"; } - Status Process(ir::IRContext* irContext) override; + Status Process() override; // Return the mask of preserved Analyses. - ir::IRContext::Analysis GetPreservedAnalyses() override { - return ir::IRContext::kAnalysisInstrToBlockMapping | - ir::IRContext::kAnalysisCombinators | ir::IRContext::kAnalysisCFG | - ir::IRContext::kAnalysisDominatorAnalysis | - ir::IRContext::kAnalysisLoopAnalysis | - ir::IRContext::kAnalysisNameMap; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG | + IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap; } }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_STRIP_REFLECT_INFO_PASS_H_ +#endif // SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/tree_iterator.h b/3rdparty/spirv-tools/source/opt/tree_iterator.h index ba724dfa5..05f42bc5b 100644 --- a/3rdparty/spirv-tools/source/opt/tree_iterator.h +++ b/3rdparty/spirv-tools/source/opt/tree_iterator.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_TREE_ITERATOR_H_ -#define LIBSPIRV_OPT_TREE_ITERATOR_H_ +#ifndef SOURCE_OPT_TREE_ITERATOR_H_ +#define SOURCE_OPT_TREE_ITERATOR_H_ #include #include @@ -243,4 +243,4 @@ class PostOrderTreeDFIterator { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_TREE_ITERATOR_H_ +#endif // SOURCE_OPT_TREE_ITERATOR_H_ diff --git a/3rdparty/spirv-tools/source/opt/type_manager.cpp b/3rdparty/spirv-tools/source/opt/type_manager.cpp index 9e2cd8652..bd5221b04 100644 --- a/3rdparty/spirv-tools/source/opt/type_manager.cpp +++ b/3rdparty/spirv-tools/source/opt/type_manager.cpp @@ -12,28 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "type_manager.h" +#include "source/opt/type_manager.h" +#include #include #include #include -#include "ir_context.h" -#include "log.h" -#include "make_unique.h" -#include "reflect.h" - -namespace { -const int kSpvTypePointerStorageClass = 1; -const int kSpvTypePointerTypeIdInIdx = 2; -} // namespace +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { namespace analysis { +namespace { -TypeManager::TypeManager(const MessageConsumer& consumer, - spvtools::ir::IRContext* c) +const int kSpvTypePointerStorageClass = 1; +const int kSpvTypePointerTypeIdInIdx = 2; + +} // namespace + +TypeManager::TypeManager(const MessageConsumer& consumer, IRContext* c) : consumer_(consumer), context_(c) { AnalyzeTypes(*c->module()); } @@ -41,6 +42,8 @@ TypeManager::TypeManager(const MessageConsumer& consumer, Type* TypeManager::GetType(uint32_t id) const { auto iter = id_to_type_.find(id); if (iter != id_to_type_.end()) return (*iter).second; + iter = id_to_incomplete_type_.find(id); + if (iter != id_to_incomplete_type_.end()) return (*iter).second; return nullptr; } @@ -48,9 +51,9 @@ std::pair> TypeManager::GetTypeAndPointerType( uint32_t id, SpvStorageClass sc) const { Type* type = GetType(id); if (type) { - return std::make_pair(type, MakeUnique(type, sc)); + return std::make_pair(type, MakeUnique(type, sc)); } else { - return std::make_pair(type, std::unique_ptr()); + return std::make_pair(type, std::unique_ptr()); } } @@ -60,13 +63,107 @@ uint32_t TypeManager::GetId(const Type* type) const { return 0; } -ForwardPointer* TypeManager::GetForwardPointer(uint32_t index) const { - if (index >= forward_pointers_.size()) return nullptr; - return forward_pointers_.at(index).get(); -} +void TypeManager::AnalyzeTypes(const Module& module) { + // First pass through the types. Any types that reference a forward pointer + // (directly or indirectly) are incomplete, and are added to incomplete types. + for (const auto* inst : module.GetTypes()) { + RecordIfTypeDefinition(*inst); + } -void TypeManager::AnalyzeTypes(const spvtools::ir::Module& module) { - for (const auto* inst : module.GetTypes()) RecordIfTypeDefinition(*inst); + if (incomplete_types_.empty()) { + return; + } + + // Get the real pointer definition for all of the forward pointers. + for (auto& type : incomplete_types_) { + if (type.type()->kind() == Type::kForwardPointer) { + auto* t = GetType(type.id()); + assert(t); + auto* p = t->AsPointer(); + assert(p); + type.type()->AsForwardPointer()->SetTargetPointer(p); + } + } + + // Replaces the references to the forward pointers in the incomplete types. + for (auto& type : incomplete_types_) { + ReplaceForwardPointers(type.type()); + } + + // Delete the forward pointers now that they are not referenced anymore. + for (auto& type : incomplete_types_) { + if (type.type()->kind() == Type::kForwardPointer) { + type.ResetType(nullptr); + } + } + + // Compare the complete types looking for types that are the same. If there + // are two types that are the same, then replace one with the other. + // Continue until we reach a fixed point. + bool restart = true; + while (restart) { + restart = false; + for (auto it1 = incomplete_types_.begin(); it1 != incomplete_types_.end(); + ++it1) { + uint32_t id1 = it1->id(); + Type* type1 = it1->type(); + if (!type1) { + continue; + } + + for (auto it2 = it1 + 1; it2 != incomplete_types_.end(); ++it2) { + uint32_t id2 = it2->id(); + (void)(id2 + id1); + Type* type2 = it2->type(); + if (!type2) { + continue; + } + + if (type1->IsSame(type2)) { + ReplaceType(type1, type2); + it2->ResetType(nullptr); + id_to_incomplete_type_[it2->id()] = type1; + restart = true; + } + } + } + } + + // Add the remaining incomplete types to the type pool. + for (auto& type : incomplete_types_) { + if (type.type() && !type.type()->AsForwardPointer()) { + std::vector decorations = + context()->get_decoration_mgr()->GetDecorationsFor(type.id(), true); + for (auto dec : decorations) { + AttachDecoration(*dec, type.type()); + } + auto pair = type_pool_.insert(type.ReleaseType()); + id_to_type_[type.id()] = pair.first->get(); + type_to_id_[pair.first->get()] = type.id(); + id_to_incomplete_type_.erase(type.id()); + } + } + + // Add a mapping for any ids that whose original type was replaced by an + // equivalent type. + for (auto& type : id_to_incomplete_type_) { + id_to_type_[type.first] = type.second; + } + +#ifndef NDEBUG + // Check if the type pool contains two types that are the same. This + // is an indication that the hashing and comparision are wrong. It + // will cause a problem if the type pool gets resized and everything + // is rehashed. + for (auto& i : type_pool_) { + for (auto& j : type_pool_) { + Type* ti = i.get(); + Type* tj = j.get(); + assert((ti == tj || !ti->IsSame(tj)) && + "Type pool contains two types that are the same."); + } + } +#endif } void TypeManager::RemoveId(uint32_t id) { @@ -105,14 +202,14 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { uint32_t id = GetId(type); if (id != 0) return id; - std::unique_ptr typeInst; + std::unique_ptr typeInst; id = context()->TakeNextId(); RegisterType(id, *type); switch (type->kind()) { -#define DefineParameterlessCase(kind) \ - case Type::k##kind: \ - typeInst.reset(new ir::Instruction(context(), SpvOpType##kind, 0, id, \ - std::initializer_list{})); \ +#define DefineParameterlessCase(kind) \ + case Type::k##kind: \ + typeInst = MakeUnique(context(), SpvOpType##kind, 0, id, \ + std::initializer_list{}); \ break; DefineParameterlessCase(Void); DefineParameterlessCase(Bool); @@ -125,45 +222,45 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { DefineParameterlessCase(NamedBarrier); #undef DefineParameterlessCase case Type::kInteger: - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypeInt, 0, id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsInteger()->width()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, - {(type->AsInteger()->IsSigned() ? 1u : 0u)}}})); + {(type->AsInteger()->IsSigned() ? 1u : 0u)}}}); break; case Type::kFloat: - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypeFloat, 0, id, - std::initializer_list{ - {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsFloat()->width()}}})); + std::initializer_list{ + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsFloat()->width()}}}); break; case Type::kVector: { uint32_t subtype = GetTypeInstruction(type->AsVector()->element_type()); - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeVector, 0, id, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {subtype}}, - {SPV_OPERAND_TYPE_LITERAL_INTEGER, - {type->AsVector()->element_count()}}})); + typeInst = + MakeUnique(context(), SpvOpTypeVector, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {subtype}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {type->AsVector()->element_count()}}}); break; } case Type::kMatrix: { uint32_t subtype = GetTypeInstruction(type->AsMatrix()->element_type()); - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeMatrix, 0, id, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {subtype}}, - {SPV_OPERAND_TYPE_LITERAL_INTEGER, - {type->AsMatrix()->element_count()}}})); + typeInst = + MakeUnique(context(), SpvOpTypeMatrix, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {subtype}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {type->AsMatrix()->element_count()}}}); break; } case Type::kImage: { const Image* image = type->AsImage(); uint32_t subtype = GetTypeInstruction(image->sampled_type()); - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypeImage, 0, id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {subtype}}, {SPV_OPERAND_TYPE_DIMENSIONALITY, {static_cast(image->dim())}}, @@ -176,45 +273,42 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { {SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT, {static_cast(image->format())}}, {SPV_OPERAND_TYPE_ACCESS_QUALIFIER, - {static_cast(image->access_qualifier())}}})); + {static_cast(image->access_qualifier())}}}); break; } case Type::kSampledImage: { uint32_t subtype = GetTypeInstruction(type->AsSampledImage()->image_type()); - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeSampledImage, 0, id, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {subtype}}})); + typeInst = MakeUnique( + context(), SpvOpTypeSampledImage, 0, id, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {subtype}}}); break; } case Type::kArray: { uint32_t subtype = GetTypeInstruction(type->AsArray()->element_type()); - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypeArray, 0, id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {subtype}}, - {SPV_OPERAND_TYPE_ID, {type->AsArray()->LengthId()}}})); + {SPV_OPERAND_TYPE_ID, {type->AsArray()->LengthId()}}}); break; } case Type::kRuntimeArray: { uint32_t subtype = GetTypeInstruction(type->AsRuntimeArray()->element_type()); - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeRuntimeArray, 0, id, - std::initializer_list{ - {SPV_OPERAND_TYPE_ID, {subtype}}})); + typeInst = MakeUnique( + context(), SpvOpTypeRuntimeArray, 0, id, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {subtype}}}); break; } case Type::kStruct: { - std::vector ops; + std::vector ops; const Struct* structTy = type->AsStruct(); for (auto ty : structTy->element_types()) { - ops.push_back( - ir::Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)})); + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)})); } - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeStruct, 0, id, ops)); + typeInst = + MakeUnique(context(), SpvOpTypeStruct, 0, id, ops); break; } case Type::kOpaque: { @@ -224,51 +318,50 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { std::vector words(size / 4 + 1, 0); char* dst = reinterpret_cast(words.data()); strncpy(dst, opaque->name().c_str(), size); - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeOpaque, 0, id, - std::initializer_list{ - {SPV_OPERAND_TYPE_LITERAL_STRING, words}})); + typeInst = MakeUnique( + context(), SpvOpTypeOpaque, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_LITERAL_STRING, words}}); break; } case Type::kPointer: { const Pointer* pointer = type->AsPointer(); uint32_t subtype = GetTypeInstruction(pointer->pointee_type()); - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypePointer, 0, id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast(pointer->storage_class())}}, - {SPV_OPERAND_TYPE_ID, {subtype}}})); + {SPV_OPERAND_TYPE_ID, {subtype}}}); break; } case Type::kFunction: { - std::vector ops; + std::vector ops; const Function* function = type->AsFunction(); - ops.push_back(ir::Operand(SPV_OPERAND_TYPE_ID, - {GetTypeInstruction(function->return_type())})); + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, + {GetTypeInstruction(function->return_type())})); for (auto ty : function->param_types()) { - ops.push_back( - ir::Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)})); + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)})); } - typeInst.reset( - new ir::Instruction(context(), SpvOpTypeFunction, 0, id, ops)); + typeInst = + MakeUnique(context(), SpvOpTypeFunction, 0, id, ops); break; } case Type::kPipe: - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypePipe, 0, id, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ACCESS_QUALIFIER, - {static_cast(type->AsPipe()->access_qualifier())}}})); + {static_cast(type->AsPipe()->access_qualifier())}}}); break; case Type::kForwardPointer: - typeInst.reset(new ir::Instruction( + typeInst = MakeUnique( context(), SpvOpTypeForwardPointer, 0, 0, - std::initializer_list{ + std::initializer_list{ {SPV_OPERAND_TYPE_ID, {type->AsForwardPointer()->target_id()}}, {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast( - type->AsForwardPointer()->storage_class())}}})); + type->AsForwardPointer()->storage_class())}}}); break; default: assert(false && "Unexpected type"); @@ -282,18 +375,17 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { uint32_t TypeManager::FindPointerToType(uint32_t type_id, SpvStorageClass storage_class) { - opt::analysis::Type* pointeeTy = context()->get_type_mgr()->GetType(type_id); - opt::analysis::Pointer pointerTy(pointeeTy, storage_class); - if (type_id == context()->get_type_mgr()->GetId(pointeeTy)) { + Type* pointeeTy = GetType(type_id); + Pointer pointerTy(pointeeTy, storage_class); + if (pointeeTy->IsUniqueType(true)) { // Non-ambiguous type. Get the pointer type through the type manager. - return context()->get_type_mgr()->GetTypeInstruction(&pointerTy); + return GetTypeInstruction(&pointerTy); } // Ambiguous type, do a linear search. - ir::Module::inst_iterator type_itr = - context()->module()->types_values_begin(); + Module::inst_iterator type_itr = context()->module()->types_values_begin(); for (; type_itr != context()->module()->types_values_end(); ++type_itr) { - const ir::Instruction* type_inst = &*type_itr; + const Instruction* type_inst = &*type_itr; if (type_inst->opcode() == SpvOpTypePointer && type_inst->GetSingleWordOperand(kSpvTypePointerTypeIdInIdx) == type_id && @@ -304,12 +396,11 @@ uint32_t TypeManager::FindPointerToType(uint32_t type_id, // Must create the pointer type. uint32_t resultId = context()->TakeNextId(); - std::unique_ptr type_inst(new ir::Instruction( - context(), SpvOpTypePointer, 0, resultId, - {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, - {uint32_t(storage_class)}}, - {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); - context()->AnalyzeDefUse(type_inst.get()); + std::unique_ptr type_inst( + new Instruction(context(), SpvOpTypePointer, 0, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {uint32_t(storage_class)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); context()->AddType(std::move(type_inst)); context()->get_type_mgr()->RegisterType(resultId, pointerTy); return resultId; @@ -332,20 +423,19 @@ void TypeManager::AttachDecorations(uint32_t id, const Type* type) { void TypeManager::CreateDecoration(uint32_t target, const std::vector& decoration, uint32_t element) { - std::vector ops; - ops.push_back(ir::Operand(SPV_OPERAND_TYPE_ID, {target})); + std::vector ops; + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {target})); if (element != 0) { - ops.push_back(ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {element})); + ops.push_back(Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {element})); } - ops.push_back(ir::Operand(SPV_OPERAND_TYPE_DECORATION, {decoration[0]})); + ops.push_back(Operand(SPV_OPERAND_TYPE_DECORATION, {decoration[0]})); for (size_t i = 1; i < decoration.size(); ++i) { - ops.push_back( - ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration[i]})); + ops.push_back(Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration[i]})); } - context()->AddAnnotationInst(MakeUnique( + context()->AddAnnotationInst(MakeUnique( context(), (element == 0 ? SpvOpDecorate : SpvOpMemberDecorate), 0, 0, ops)); - ir::Instruction* inst = &*--context()->annotation_end(); + Instruction* inst = &*--context()->annotation_end(); context()->get_def_use_mgr()->AnalyzeInstUse(inst); } @@ -360,7 +450,8 @@ Type* TypeManager::RebuildType(const Type& type) { #define DefineNoSubtypeCase(kind) \ case Type::k##kind: \ rebuilt_ty.reset(type.Clone().release()); \ - break; + return type_pool_.insert(std::move(rebuilt_ty)).first->get(); + DefineNoSubtypeCase(Void); DefineNoSubtypeCase(Bool); DefineNoSubtypeCase(Integer); @@ -378,55 +469,54 @@ Type* TypeManager::RebuildType(const Type& type) { case Type::kVector: { const Vector* vec_ty = type.AsVector(); const Type* ele_ty = vec_ty->element_type(); - rebuilt_ty.reset( - new Vector(RebuildType(*ele_ty), vec_ty->element_count())); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), vec_ty->element_count()); break; } case Type::kMatrix: { const Matrix* mat_ty = type.AsMatrix(); const Type* ele_ty = mat_ty->element_type(); - rebuilt_ty.reset( - new Matrix(RebuildType(*ele_ty), mat_ty->element_count())); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), mat_ty->element_count()); break; } case Type::kImage: { const Image* image_ty = type.AsImage(); const Type* ele_ty = image_ty->sampled_type(); - rebuilt_ty.reset(new Image(RebuildType(*ele_ty), image_ty->dim(), - image_ty->depth(), image_ty->is_arrayed(), - image_ty->is_multisampled(), - image_ty->sampled(), image_ty->format(), - image_ty->access_qualifier())); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), image_ty->dim(), + image_ty->depth(), image_ty->is_arrayed(), + image_ty->is_multisampled(), image_ty->sampled(), + image_ty->format(), image_ty->access_qualifier()); break; } case Type::kSampledImage: { const SampledImage* image_ty = type.AsSampledImage(); const Type* ele_ty = image_ty->image_type(); - rebuilt_ty.reset( - - new SampledImage(RebuildType(*ele_ty))); + rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); break; } case Type::kArray: { const Array* array_ty = type.AsArray(); const Type* ele_ty = array_ty->element_type(); - rebuilt_ty.reset(new Array(RebuildType(*ele_ty), array_ty->LengthId())); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), array_ty->LengthId()); break; } case Type::kRuntimeArray: { const RuntimeArray* array_ty = type.AsRuntimeArray(); const Type* ele_ty = array_ty->element_type(); - rebuilt_ty.reset(new RuntimeArray(RebuildType(*ele_ty))); + rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); break; } case Type::kStruct: { const Struct* struct_ty = type.AsStruct(); - std::vector subtypes; + std::vector subtypes; subtypes.reserve(struct_ty->element_types().size()); for (const auto* ele_ty : struct_ty->element_types()) { subtypes.push_back(RebuildType(*ele_ty)); } - rebuilt_ty.reset(new Struct(subtypes)); + rebuilt_ty = MakeUnique(subtypes); Struct* rebuilt_struct = rebuilt_ty->AsStruct(); for (auto pair : struct_ty->element_decorations()) { uint32_t index = pair.first; @@ -441,25 +531,25 @@ Type* TypeManager::RebuildType(const Type& type) { case Type::kPointer: { const Pointer* pointer_ty = type.AsPointer(); const Type* ele_ty = pointer_ty->pointee_type(); - rebuilt_ty.reset( - new Pointer(RebuildType(*ele_ty), pointer_ty->storage_class())); + rebuilt_ty = MakeUnique(RebuildType(*ele_ty), + pointer_ty->storage_class()); break; } case Type::kFunction: { const Function* function_ty = type.AsFunction(); const Type* ret_ty = function_ty->return_type(); - std::vector param_types; + std::vector param_types; param_types.reserve(function_ty->param_types().size()); for (const auto* param_ty : function_ty->param_types()) { param_types.push_back(RebuildType(*param_ty)); } - rebuilt_ty.reset(new Function(RebuildType(*ret_ty), param_types)); + rebuilt_ty = MakeUnique(RebuildType(*ret_ty), param_types); break; } case Type::kForwardPointer: { const ForwardPointer* forward_ptr_ty = type.AsForwardPointer(); - rebuilt_ty.reset(new ForwardPointer(forward_ptr_ty->target_id(), - forward_ptr_ty->storage_class())); + rebuilt_ty = MakeUnique(forward_ptr_ty->target_id(), + forward_ptr_ty->storage_class()); const Pointer* target_ptr = forward_ptr_ty->target_pointer(); if (target_ptr) { rebuilt_ty->AsForwardPointer()->SetTargetPointer( @@ -496,9 +586,8 @@ Type* TypeManager::GetRegisteredType(const Type* type) { return GetType(id); } -Type* TypeManager::RecordIfTypeDefinition( - const spvtools::ir::Instruction& inst) { - if (!spvtools::ir::IsTypeInst(inst.opcode())) return nullptr; +Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { + if (!IsTypeInst(inst.opcode())) return nullptr; Type* type = nullptr; switch (inst.opcode()) { @@ -544,42 +633,79 @@ Type* TypeManager::RecordIfTypeDefinition( case SpvOpTypeArray: type = new Array(GetType(inst.GetSingleWordInOperand(0)), inst.GetSingleWordInOperand(1)); + if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } break; case SpvOpTypeRuntimeArray: type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0))); + if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } break; case SpvOpTypeStruct: { - std::vector element_types; + std::vector element_types; + bool incomplete_type = false; for (uint32_t i = 0; i < inst.NumInOperands(); ++i) { - element_types.push_back(GetType(inst.GetSingleWordInOperand(i))); + uint32_t type_id = inst.GetSingleWordInOperand(i); + element_types.push_back(GetType(type_id)); + if (id_to_incomplete_type_.count(type_id)) { + incomplete_type = true; + } } type = new Struct(element_types); + + if (incomplete_type) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } } break; case SpvOpTypeOpaque: { const uint32_t* data = inst.GetInOperand(0).words.data(); type = new Opaque(reinterpret_cast(data)); } break; case SpvOpTypePointer: { - auto* ptr = new Pointer( - GetType(inst.GetSingleWordInOperand(1)), + uint32_t pointee_type_id = inst.GetSingleWordInOperand(1); + type = new Pointer( + GetType(pointee_type_id), static_cast(inst.GetSingleWordInOperand(0))); - // Let's see if somebody forward references this pointer. - for (auto* fp : unresolved_forward_pointers_) { - if (fp->target_id() == inst.result_id()) { - fp->SetTargetPointer(ptr); - unresolved_forward_pointers_.erase(fp); - break; - } + + if (id_to_incomplete_type_.count(pointee_type_id)) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; } - type = ptr; + id_to_incomplete_type_.erase(inst.result_id()); + } break; case SpvOpTypeFunction: { - Type* return_type = GetType(inst.GetSingleWordInOperand(0)); - std::vector param_types; - for (uint32_t i = 1; i < inst.NumInOperands(); ++i) { - param_types.push_back(GetType(inst.GetSingleWordInOperand(i))); + bool incomplete_type = false; + uint32_t return_type_id = inst.GetSingleWordInOperand(0); + if (id_to_incomplete_type_.count(return_type_id)) { + incomplete_type = true; } + Type* return_type = GetType(return_type_id); + std::vector param_types; + for (uint32_t i = 1; i < inst.NumInOperands(); ++i) { + uint32_t param_type_id = inst.GetSingleWordInOperand(i); + param_types.push_back(GetType(param_type_id)); + if (id_to_incomplete_type_.count(param_type_id)) { + incomplete_type = true; + } + } + type = new Function(return_type, param_types); + + if (incomplete_type) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } } break; case SpvOpTypeEvent: type = new Event(); @@ -599,12 +725,12 @@ Type* TypeManager::RecordIfTypeDefinition( break; case SpvOpTypeForwardPointer: { // Handling of forward pointers is different from the other types. - auto* fp = new ForwardPointer( - inst.GetSingleWordInOperand(0), - static_cast(inst.GetSingleWordInOperand(1))); - forward_pointers_.emplace_back(fp); - unresolved_forward_pointers_.insert(fp); - return fp; + uint32_t target_id = inst.GetSingleWordInOperand(0); + type = new ForwardPointer(target_id, static_cast( + inst.GetSingleWordInOperand(1))); + incomplete_types_.emplace_back(target_id, type); + id_to_incomplete_type_[target_id] = type; + return type; } case SpvOpTypePipeStorage: type = new PipeStorage(); @@ -618,28 +744,24 @@ Type* TypeManager::RecordIfTypeDefinition( } uint32_t id = inst.result_id(); - if (id == 0) { - SPIRV_ASSERT(consumer_, inst.opcode() == SpvOpTypeForwardPointer, - "instruction without result id found"); - } else { - SPIRV_ASSERT(consumer_, type != nullptr, - "type should not be nullptr at this point"); - std::vector decorations = - context()->get_decoration_mgr()->GetDecorationsFor(id, true); - for (auto dec : decorations) { - AttachDecoration(*dec, type); - } - std::unique_ptr unique(type); - auto pair = type_pool_.insert(std::move(unique)); - id_to_type_[id] = pair.first->get(); - type_to_id_[pair.first->get()] = id; + SPIRV_ASSERT(consumer_, id != 0, "instruction without result id found"); + SPIRV_ASSERT(consumer_, type != nullptr, + "type should not be nullptr at this point"); + std::vector decorations = + context()->get_decoration_mgr()->GetDecorationsFor(id, true); + for (auto dec : decorations) { + AttachDecoration(*dec, type); } + std::unique_ptr unique(type); + auto pair = type_pool_.insert(std::move(unique)); + id_to_type_[id] = pair.first->get(); + type_to_id_[pair.first->get()] = id; return type; } -void TypeManager::AttachDecoration(const ir::Instruction& inst, Type* type) { +void TypeManager::AttachDecoration(const Instruction& inst, Type* type) { const SpvOp opcode = inst.opcode(); - if (!ir::IsAnnotationInst(opcode)) return; + if (!IsAnnotationInst(opcode)) return; switch (opcode) { case SpvOpDecorate: { @@ -672,16 +794,16 @@ void TypeManager::AttachDecoration(const ir::Instruction& inst, Type* type) { const Type* TypeManager::GetMemberType( const Type* parent_type, const std::vector& access_chain) { for (uint32_t element_index : access_chain) { - if (const analysis::Struct* struct_type = parent_type->AsStruct()) { + if (const Struct* struct_type = parent_type->AsStruct()) { parent_type = struct_type->element_types()[element_index]; - } else if (const analysis::Array* array_type = parent_type->AsArray()) { + } else if (const Array* array_type = parent_type->AsArray()) { parent_type = array_type->element_type(); - } else if (const analysis::RuntimeArray* runtime_array_type = + } else if (const RuntimeArray* runtime_array_type = parent_type->AsRuntimeArray()) { parent_type = runtime_array_type->element_type(); - } else if (const analysis::Vector* vector_type = parent_type->AsVector()) { + } else if (const Vector* vector_type = parent_type->AsVector()) { parent_type = vector_type->element_type(); - } else if (const analysis::Matrix* matrix_type = parent_type->AsMatrix()) { + } else if (const Matrix* matrix_type = parent_type->AsMatrix()) { parent_type = matrix_type->element_type(); } else { assert(false && "Trying to get a member of a type without members."); @@ -690,6 +812,115 @@ const Type* TypeManager::GetMemberType( return parent_type; } +void TypeManager::ReplaceForwardPointers(Type* type) { + switch (type->kind()) { + case Type::kArray: { + const ForwardPointer* element_type = + type->AsArray()->element_type()->AsForwardPointer(); + if (element_type) { + type->AsArray()->ReplaceElementType(element_type->target_pointer()); + } + } break; + case Type::kRuntimeArray: { + const ForwardPointer* element_type = + type->AsRuntimeArray()->element_type()->AsForwardPointer(); + if (element_type) { + type->AsRuntimeArray()->ReplaceElementType( + element_type->target_pointer()); + } + } break; + case Type::kStruct: { + auto& member_types = type->AsStruct()->element_types(); + for (auto& member_type : member_types) { + if (member_type->AsForwardPointer()) { + member_type = member_type->AsForwardPointer()->target_pointer(); + assert(member_type); + } + } + } break; + case Type::kPointer: { + const ForwardPointer* pointee_type = + type->AsPointer()->pointee_type()->AsForwardPointer(); + if (pointee_type) { + type->AsPointer()->SetPointeeType(pointee_type->target_pointer()); + } + } break; + case Type::kFunction: { + Function* func_type = type->AsFunction(); + const ForwardPointer* return_type = + func_type->return_type()->AsForwardPointer(); + if (return_type) { + func_type->SetReturnType(return_type->target_pointer()); + } + + auto& param_types = func_type->param_types(); + for (auto& param_type : param_types) { + if (param_type->AsForwardPointer()) { + param_type = param_type->AsForwardPointer()->target_pointer(); + } + } + } break; + default: + break; + } +} + +void TypeManager::ReplaceType(Type* new_type, Type* original_type) { + assert(original_type->kind() == new_type->kind() && + "Types must be the same for replacement.\n"); + for (auto& p : incomplete_types_) { + Type* type = p.type(); + if (!type) { + continue; + } + + switch (type->kind()) { + case Type::kArray: { + const Type* element_type = type->AsArray()->element_type(); + if (element_type == original_type) { + type->AsArray()->ReplaceElementType(new_type); + } + } break; + case Type::kRuntimeArray: { + const Type* element_type = type->AsRuntimeArray()->element_type(); + if (element_type == original_type) { + type->AsRuntimeArray()->ReplaceElementType(new_type); + } + } break; + case Type::kStruct: { + auto& member_types = type->AsStruct()->element_types(); + for (auto& member_type : member_types) { + if (member_type == original_type) { + member_type = new_type; + } + } + } break; + case Type::kPointer: { + const Type* pointee_type = type->AsPointer()->pointee_type(); + if (pointee_type == original_type) { + type->AsPointer()->SetPointeeType(new_type); + } + } break; + case Type::kFunction: { + Function* func_type = type->AsFunction(); + const Type* return_type = func_type->return_type(); + if (return_type == original_type) { + func_type->SetReturnType(new_type); + } + + auto& param_types = func_type->param_types(); + for (auto& param_type : param_types) { + if (param_type == original_type) { + param_type = new_type; + } + } + } break; + default: + break; + } + } +} + } // namespace analysis } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/type_manager.h b/3rdparty/spirv-tools/source/opt/type_manager.h index 2020d8b9a..c44969e84 100644 --- a/3rdparty/spirv-tools/source/opt/type_manager.h +++ b/3rdparty/spirv-tools/source/opt/type_manager.h @@ -12,23 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_TYPE_MANAGER_H_ -#define LIBSPIRV_OPT_TYPE_MANAGER_H_ +#ifndef SOURCE_OPT_TYPE_MANAGER_H_ +#define SOURCE_OPT_TYPE_MANAGER_H_ #include #include #include +#include #include -#include "module.h" +#include "source/opt/module.h" +#include "source/opt/types.h" #include "spirv-tools/libspirv.hpp" -#include "types.h" namespace spvtools { -namespace ir { -class IRContext; -} // namespace ir namespace opt { + +class IRContext; + namespace analysis { // Hashing functor. @@ -75,7 +76,7 @@ class TypeManager { // will be communicated to the outside via the given message |consumer|. // This instance only keeps a reference to the |consumer|, so the |consumer| // should outlive this instance. - TypeManager(const MessageConsumer& consumer, spvtools::ir::IRContext* c); + TypeManager(const MessageConsumer& consumer, IRContext* c); TypeManager(const TypeManager&) = delete; TypeManager(TypeManager&&) = delete; @@ -94,11 +95,6 @@ class TypeManager { IdToTypeMap::const_iterator begin() const { return id_to_type_.cbegin(); } IdToTypeMap::const_iterator end() const { return id_to_type_.cend(); } - // Returns the forward pointer type at the given |index|. - ForwardPointer* GetForwardPointer(uint32_t index) const; - // Returns the number of forward pointer types hold in this manager. - size_t NumForwardPointers() const { return forward_pointers_.size(); } - // Returns a pair of the type and pointer to the type in |sc|. // // |id| must be a registered type. @@ -146,14 +142,31 @@ class TypeManager { using TypePool = std::unordered_set, HashTypeUniquePointer, CompareTypeUniquePointers>; - using ForwardPointerVector = std::vector>; + + class UnresolvedType { + public: + UnresolvedType(uint32_t i, Type* t) : id_(i), type_(t) {} + UnresolvedType(const UnresolvedType&) = delete; + UnresolvedType(UnresolvedType&& that) + : id_(that.id_), type_(std::move(that.type_)) {} + + uint32_t id() { return id_; } + Type* type() { return type_.get(); } + std::unique_ptr&& ReleaseType() { return std::move(type_); } + void ResetType(Type* t) { type_.reset(t); } + + private: + uint32_t id_; + std::unique_ptr type_; + }; + using IdToUnresolvedType = std::vector; // Analyzes the types and decorations on types in the given |module|. - void AnalyzeTypes(const spvtools::ir::Module& module); + void AnalyzeTypes(const Module& module); - spvtools::ir::IRContext* context() { return context_; } + IRContext* context() { return context_; } - // Attachs the decorations on |type| to |id|. + // Attaches the decorations on |type| to |id|. void AttachDecorations(uint32_t id, const Type* type); // Create the annotation instruction. @@ -166,30 +179,40 @@ class TypeManager { // Creates and returns a type from the given SPIR-V |inst|. Returns nullptr if // the given instruction is not for defining a type. - Type* RecordIfTypeDefinition(const spvtools::ir::Instruction& inst); + Type* RecordIfTypeDefinition(const Instruction& inst); // Attaches the decoration encoded in |inst| to |type|. Does nothing if the // given instruction is not a decoration instruction. Assumes the target is // |type| (e.g. should be called in loop of |type|'s decorations). - void AttachDecoration(const spvtools::ir::Instruction& inst, Type* type); + void AttachDecoration(const Instruction& inst, Type* type); // Returns an equivalent pointer to |type| built in terms of pointers owned by // |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt // replacing the bool subtype with one owned by |type_pool_|. Type* RebuildType(const Type& type); + // Completes the incomplete type |type|, by replaces all references to + // ForwardPointer by the defining Pointer. + void ReplaceForwardPointers(Type* type); + + // Replaces all references to |original_type| in |incomplete_types_| by + // |new_type|. + void ReplaceType(Type* new_type, Type* original_type); + const MessageConsumer& consumer_; // Message consumer. - spvtools::ir::IRContext* context_; + IRContext* context_; IdToTypeMap id_to_type_; // Mapping from ids to their type representations. TypeToIdMap type_to_id_; // Mapping from types to their defining ids. TypePool type_pool_; // Memory owner of type pointers. - ForwardPointerVector forward_pointers_; // All forward pointer declarations. - // All unresolved forward pointer declarations. - // Refers the contents in the above vector. - std::unordered_set unresolved_forward_pointers_; + IdToUnresolvedType incomplete_types_; // All incomplete types. Stored in an + // std::vector to make traversals + // deterministic. + + IdToTypeMap id_to_incomplete_type_; // Maps ids to their type representations + // for incomplete types. }; } // namespace analysis } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_TYPE_MANAGER_H_ +#endif // SOURCE_OPT_TYPE_MANAGER_H_ diff --git a/3rdparty/spirv-tools/source/opt/types.cpp b/3rdparty/spirv-tools/source/opt/types.cpp index 4838f8e50..15cff5486 100644 --- a/3rdparty/spirv-tools/source/opt/types.cpp +++ b/3rdparty/spirv-tools/source/opt/types.cpp @@ -16,8 +16,10 @@ #include #include #include +#include -#include "types.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { @@ -94,9 +96,9 @@ bool Type::IsUniqueType(bool allowVariablePointers) const { std::unique_ptr Type::Clone() const { std::unique_ptr type; switch (kind_) { -#define DeclareKindCase(kind) \ - case k##kind: \ - type.reset(new kind(*this->As##kind())); \ +#define DeclareKindCase(kind) \ + case k##kind: \ + type = MakeUnique(*this->As##kind()); \ break; DeclareKindCase(Void); DeclareKindCase(Bool); @@ -171,7 +173,12 @@ bool Type::operator==(const Type& other) const { } } -void Type::GetHashWords(std::vector* words) const { +void Type::GetHashWords(std::vector* words, + std::unordered_set* seen) const { + if (!seen->insert(this).second) { + return; + } + words->push_back(kind_); for (const auto& d : decorations_) { for (auto w : d) { @@ -180,9 +187,9 @@ void Type::GetHashWords(std::vector* words) const { } switch (kind_) { -#define DeclareKindCase(type) \ - case k##type: \ - As##type()->GetExtraHashWords(words); \ +#define DeclareKindCase(type) \ + case k##type: \ + As##type()->GetExtraHashWords(words, seen); \ break; DeclareKindCase(Void); DeclareKindCase(Bool); @@ -212,6 +219,8 @@ void Type::GetHashWords(std::vector* words) const { assert(false && "Unhandled type"); break; } + + seen->erase(this); } size_t Type::HashValue() const { @@ -225,7 +234,7 @@ size_t Type::HashValue() const { return std::hash()(h); } -bool Integer::IsSame(const Type* that) const { +bool Integer::IsSameImpl(const Type* that, IsSameCache*) const { const Integer* it = that->AsInteger(); return it && width_ == it->width_ && signed_ == it->signed_ && HasSameDecorations(that); @@ -237,12 +246,13 @@ std::string Integer::str() const { return oss.str(); } -void Integer::GetExtraHashWords(std::vector* words) const { +void Integer::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { words->push_back(width_); words->push_back(signed_); } -bool Float::IsSame(const Type* that) const { +bool Float::IsSameImpl(const Type* that, IsSameCache*) const { const Float* ft = that->AsFloat(); return ft && width_ == ft->width_ && HasSameDecorations(that); } @@ -253,7 +263,8 @@ std::string Float::str() const { return oss.str(); } -void Float::GetExtraHashWords(std::vector* words) const { +void Float::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { words->push_back(width_); } @@ -262,10 +273,11 @@ Vector::Vector(Type* type, uint32_t count) assert(type->AsBool() || type->AsInteger() || type->AsFloat()); } -bool Vector::IsSame(const Type* that) const { +bool Vector::IsSameImpl(const Type* that, IsSameCache* seen) const { const Vector* vt = that->AsVector(); if (!vt) return false; - return count_ == vt->count_ && element_type_->IsSame(vt->element_type_) && + return count_ == vt->count_ && + element_type_->IsSameImpl(vt->element_type_, seen) && HasSameDecorations(that); } @@ -275,8 +287,9 @@ std::string Vector::str() const { return oss.str(); } -void Vector::GetExtraHashWords(std::vector* words) const { - element_type_->GetHashWords(words); +void Vector::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); words->push_back(count_); } @@ -285,10 +298,11 @@ Matrix::Matrix(Type* type, uint32_t count) assert(type->AsVector()); } -bool Matrix::IsSame(const Type* that) const { +bool Matrix::IsSameImpl(const Type* that, IsSameCache* seen) const { const Matrix* mt = that->AsMatrix(); if (!mt) return false; - return count_ == mt->count_ && element_type_->IsSame(mt->element_type_) && + return count_ == mt->count_ && + element_type_->IsSameImpl(mt->element_type_, seen) && HasSameDecorations(that); } @@ -298,8 +312,9 @@ std::string Matrix::str() const { return oss.str(); } -void Matrix::GetExtraHashWords(std::vector* words) const { - element_type_->GetHashWords(words); +void Matrix::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); words->push_back(count_); } @@ -317,13 +332,14 @@ Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample, // TODO(antiagainst): check sampled_type } -bool Image::IsSame(const Type* that) const { +bool Image::IsSameImpl(const Type* that, IsSameCache* seen) const { const Image* it = that->AsImage(); if (!it) return false; return dim_ == it->dim_ && depth_ == it->depth_ && arrayed_ == it->arrayed_ && ms_ == it->ms_ && sampled_ == it->sampled_ && format_ == it->format_ && access_qualifier_ == it->access_qualifier_ && - sampled_type_->IsSame(it->sampled_type_) && HasSameDecorations(that); + sampled_type_->IsSameImpl(it->sampled_type_, seen) && + HasSameDecorations(that); } std::string Image::str() const { @@ -334,8 +350,9 @@ std::string Image::str() const { return oss.str(); } -void Image::GetExtraHashWords(std::vector* words) const { - sampled_type_->GetHashWords(words); +void Image::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + sampled_type_->GetHashWords(words, seen); words->push_back(dim_); words->push_back(depth_); words->push_back(arrayed_); @@ -345,10 +362,11 @@ void Image::GetExtraHashWords(std::vector* words) const { words->push_back(access_qualifier_); } -bool SampledImage::IsSame(const Type* that) const { +bool SampledImage::IsSameImpl(const Type* that, IsSameCache* seen) const { const SampledImage* sit = that->AsSampledImage(); if (!sit) return false; - return image_type_->IsSame(sit->image_type_) && HasSameDecorations(that); + return image_type_->IsSameImpl(sit->image_type_, seen) && + HasSameDecorations(that); } std::string SampledImage::str() const { @@ -357,8 +375,9 @@ std::string SampledImage::str() const { return oss.str(); } -void SampledImage::GetExtraHashWords(std::vector* words) const { - image_type_->GetHashWords(words); +void SampledImage::GetExtraHashWords( + std::vector* words, std::unordered_set* seen) const { + image_type_->GetHashWords(words, seen); } Array::Array(Type* type, uint32_t length_id) @@ -366,11 +385,12 @@ Array::Array(Type* type, uint32_t length_id) assert(!type->AsVoid()); } -bool Array::IsSame(const Type* that) const { +bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const { const Array* at = that->AsArray(); if (!at) return false; return length_id_ == at->length_id_ && - element_type_->IsSame(at->element_type_) && HasSameDecorations(that); + element_type_->IsSameImpl(at->element_type_, seen) && + HasSameDecorations(that); } std::string Array::str() const { @@ -379,20 +399,24 @@ std::string Array::str() const { return oss.str(); } -void Array::GetExtraHashWords(std::vector* words) const { - element_type_->GetHashWords(words); +void Array::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); words->push_back(length_id_); } +void Array::ReplaceElementType(const Type* type) { element_type_ = type; } + RuntimeArray::RuntimeArray(Type* type) : Type(kRuntimeArray), element_type_(type) { assert(!type->AsVoid()); } -bool RuntimeArray::IsSame(const Type* that) const { +bool RuntimeArray::IsSameImpl(const Type* that, IsSameCache* seen) const { const RuntimeArray* rat = that->AsRuntimeArray(); if (!rat) return false; - return element_type_->IsSame(rat->element_type_) && HasSameDecorations(that); + return element_type_->IsSameImpl(rat->element_type_, seen) && + HasSameDecorations(that); } std::string RuntimeArray::str() const { @@ -401,11 +425,16 @@ std::string RuntimeArray::str() const { return oss.str(); } -void RuntimeArray::GetExtraHashWords(std::vector* words) const { - element_type_->GetHashWords(words); +void RuntimeArray::GetExtraHashWords( + std::vector* words, std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); } -Struct::Struct(const std::vector& types) +void RuntimeArray::ReplaceElementType(const Type* type) { + element_type_ = type; +} + +Struct::Struct(const std::vector& types) : Type(kStruct), element_types_(types) { for (const auto* t : types) { (void)t; @@ -423,7 +452,7 @@ void Struct::AddMemberDecoration(uint32_t index, element_decorations_[index].push_back(std::move(decoration)); } -bool Struct::IsSame(const Type* that) const { +bool Struct::IsSameImpl(const Type* that, IsSameCache* seen) const { const Struct* st = that->AsStruct(); if (!st) return false; if (element_types_.size() != st->element_types_.size()) return false; @@ -432,7 +461,8 @@ bool Struct::IsSame(const Type* that) const { if (!HasSameDecorations(that)) return false; for (size_t i = 0; i < element_types_.size(); ++i) { - if (!element_types_[i]->IsSame(st->element_types_[i])) return false; + if (!element_types_[i]->IsSameImpl(st->element_types_[i], seen)) + return false; } for (const auto& p : element_decorations_) { if (st->element_decorations_.count(p.first) == 0) return false; @@ -454,9 +484,10 @@ std::string Struct::str() const { return oss.str(); } -void Struct::GetExtraHashWords(std::vector* words) const { +void Struct::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { for (auto* t : element_types_) { - t->GetHashWords(words); + t->GetHashWords(words, seen); } for (const auto& pair : element_decorations_) { words->push_back(pair.first); @@ -468,7 +499,7 @@ void Struct::GetExtraHashWords(std::vector* words) const { } } -bool Opaque::IsSame(const Type* that) const { +bool Opaque::IsSameImpl(const Type* that, IsSameCache*) const { const Opaque* ot = that->AsOpaque(); if (!ot) return false; return name_ == ot->name_ && HasSameDecorations(that); @@ -480,7 +511,8 @@ std::string Opaque::str() const { return oss.str(); } -void Opaque::GetExtraHashWords(std::vector* words) const { +void Opaque::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { for (auto c : name_) { words->push_back(static_cast(c)); } @@ -489,22 +521,33 @@ void Opaque::GetExtraHashWords(std::vector* words) const { Pointer::Pointer(const Type* type, SpvStorageClass sc) : Type(kPointer), pointee_type_(type), storage_class_(sc) {} -bool Pointer::IsSame(const Type* that) const { +bool Pointer::IsSameImpl(const Type* that, IsSameCache* seen) const { const Pointer* pt = that->AsPointer(); if (!pt) return false; if (storage_class_ != pt->storage_class_) return false; - if (!pointee_type_->IsSame(pt->pointee_type_)) return false; + auto p = seen->insert(std::make_pair(this, that->AsPointer())); + if (!p.second) { + return true; + } + bool same_pointee = pointee_type_->IsSameImpl(pt->pointee_type_, seen); + seen->erase(p.first); + if (!same_pointee) { + return false; + } return HasSameDecorations(that); } std::string Pointer::str() const { return pointee_type_->str() + "*"; } -void Pointer::GetExtraHashWords(std::vector* words) const { - pointee_type_->GetHashWords(words); +void Pointer::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + pointee_type_->GetHashWords(words, seen); words->push_back(storage_class_); } -Function::Function(Type* ret_type, const std::vector& params) +void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; } + +Function::Function(Type* ret_type, const std::vector& params) : Type(kFunction), return_type_(ret_type), param_types_(params) { for (auto* t : params) { (void)t; @@ -512,13 +555,13 @@ Function::Function(Type* ret_type, const std::vector& params) } } -bool Function::IsSame(const Type* that) const { +bool Function::IsSameImpl(const Type* that, IsSameCache* seen) const { const Function* ft = that->AsFunction(); if (!ft) return false; - if (!return_type_->IsSame(ft->return_type_)) return false; + if (!return_type_->IsSameImpl(ft->return_type_, seen)) return false; if (param_types_.size() != ft->param_types_.size()) return false; for (size_t i = 0; i < param_types_.size(); ++i) { - if (!param_types_[i]->IsSame(ft->param_types_[i])) return false; + if (!param_types_[i]->IsSameImpl(ft->param_types_[i], seen)) return false; } return HasSameDecorations(that); } @@ -535,14 +578,17 @@ std::string Function::str() const { return oss.str(); } -void Function::GetExtraHashWords(std::vector* words) const { - return_type_->GetHashWords(words); +void Function::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + return_type_->GetHashWords(words, seen); for (const auto* t : param_types_) { - t->GetHashWords(words); + t->GetHashWords(words, seen); } } -bool Pipe::IsSame(const Type* that) const { +void Function::SetReturnType(const Type* type) { return_type_ = type; } + +bool Pipe::IsSameImpl(const Type* that, IsSameCache*) const { const Pipe* pt = that->AsPipe(); if (!pt) return false; return access_qualifier_ == pt->access_qualifier_ && HasSameDecorations(that); @@ -554,11 +600,12 @@ std::string Pipe::str() const { return oss.str(); } -void Pipe::GetExtraHashWords(std::vector* words) const { +void Pipe::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { words->push_back(access_qualifier_); } -bool ForwardPointer::IsSame(const Type* that) const { +bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const { const ForwardPointer* fpt = that->AsForwardPointer(); if (!fpt) return false; return target_id_ == fpt->target_id_ && @@ -577,10 +624,11 @@ std::string ForwardPointer::str() const { return oss.str(); } -void ForwardPointer::GetExtraHashWords(std::vector* words) const { +void ForwardPointer::GetExtraHashWords( + std::vector* words, std::unordered_set* seen) const { words->push_back(target_id_); words->push_back(storage_class_); - if (pointer_) pointer_->GetHashWords(words); + if (pointer_) pointer_->GetHashWords(words, seen); } } // namespace analysis diff --git a/3rdparty/spirv-tools/source/opt/types.h b/3rdparty/spirv-tools/source/opt/types.h index ee81b769b..625f342a6 100644 --- a/3rdparty/spirv-tools/source/opt/types.h +++ b/3rdparty/spirv-tools/source/opt/types.h @@ -14,16 +14,19 @@ // This file provides a class hierarchy for representing SPIR-V types. -#ifndef LIBSPIRV_OPT_TYPES_H_ -#define LIBSPIRV_OPT_TYPES_H_ +#ifndef SOURCE_OPT_TYPES_H_ +#define SOURCE_OPT_TYPES_H_ #include #include +#include #include #include +#include +#include #include -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" #include "spirv-tools/libspirv.h" namespace spvtools { @@ -58,6 +61,8 @@ class NamedBarrier; // which is used as a way to probe the actual . class Type { public: + typedef std::set> IsSameCache; + // Available subtypes. // // When adding a new derived class of Type, please add an entry to the enum. @@ -101,7 +106,16 @@ class Type { bool HasSameDecorations(const Type* that) const; // Returns true if this type is exactly the same as |that| type, including // decorations. - virtual bool IsSame(const Type* that) const = 0; + bool IsSame(const Type* that) const { + IsSameCache seen; + return IsSameImpl(that, &seen); + } + + // Returns true if this type is exactly the same as |that| type, including + // decorations. |seen| is the set of |Pointer*| pair that are currently being + // compared in a parent call to |IsSameImpl|. + virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0; + // Returns a human-readable string to represent this type. virtual std::string str() const = 0; @@ -164,11 +178,20 @@ class Type { size_t HashValue() const; // Adds the necessary words to compute a hash value of this type to |words|. - void GetHashWords(std::vector* words) const; + void GetHashWords(std::vector* words) const { + std::unordered_set seen; + GetHashWords(words, &seen); + } + + // Adds the necessary words to compute a hash value of this type to |words|. + void GetHashWords(std::vector* words, + std::unordered_set* seen) const; // Adds necessary extra words for a subtype to calculate a hash value into // |words|. - virtual void GetExtraHashWords(std::vector* words) const = 0; + virtual void GetExtraHashWords( + std::vector* words, + std::unordered_set* pSet) const = 0; protected: // Decorations attached to this type. Each decoration is encoded as a vector @@ -190,7 +213,6 @@ class Integer : public Type { : Type(kInteger), width_(w), signed_(is_signed) {} Integer(const Integer&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; Integer* AsInteger() override { return this; } @@ -198,9 +220,12 @@ class Integer : public Type { uint32_t width() const { return width_; } bool IsSigned() const { return signed_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + uint32_t width_; // bit width bool signed_; // true if this integer is signed }; @@ -210,16 +235,18 @@ class Float : public Type { Float(uint32_t w) : Type(kFloat), width_(w) {} Float(const Float&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; Float* AsFloat() override { return this; } const Float* AsFloat() const override { return this; } uint32_t width() const { return width_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + uint32_t width_; // bit width }; @@ -228,7 +255,6 @@ class Vector : public Type { Vector(Type* element_type, uint32_t count); Vector(const Vector&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; const Type* element_type() const { return element_type_; } uint32_t element_count() const { return count_; } @@ -236,10 +262,13 @@ class Vector : public Type { Vector* AsVector() override { return this; } const Vector* AsVector() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: - Type* element_type_; + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; uint32_t count_; }; @@ -248,7 +277,6 @@ class Matrix : public Type { Matrix(Type* element_type, uint32_t count); Matrix(const Matrix&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; const Type* element_type() const { return element_type_; } uint32_t element_count() const { return count_; } @@ -256,10 +284,13 @@ class Matrix : public Type { Matrix* AsMatrix() override { return this; } const Matrix* AsMatrix() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: - Type* element_type_; + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; uint32_t count_; }; @@ -270,7 +301,6 @@ class Image : public Type { SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly); Image(const Image&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; Image* AsImage() override { return this; } @@ -285,9 +315,12 @@ class Image : public Type { SpvImageFormat format() const { return format_; } SpvAccessQualifier access_qualifier() const { return access_qualifier_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + Type* sampled_type_; SpvDim dim_; uint32_t depth_; @@ -303,7 +336,6 @@ class SampledImage : public Type { SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {} SampledImage(const SampledImage&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; SampledImage* AsSampledImage() override { return this; } @@ -311,9 +343,11 @@ class SampledImage : public Type { const Type* image_type() const { return image_type_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; Type* image_type_; }; @@ -322,7 +356,6 @@ class Array : public Type { Array(Type* element_type, uint32_t length_id); Array(const Array&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; const Type* element_type() const { return element_type_; } uint32_t LengthId() const { return length_id_; } @@ -330,10 +363,15 @@ class Array : public Type { Array* AsArray() override { return this; } const Array* AsArray() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + void ReplaceElementType(const Type* element_type); private: - Type* element_type_; + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; uint32_t length_id_; }; @@ -342,31 +380,37 @@ class RuntimeArray : public Type { RuntimeArray(Type* element_type); RuntimeArray(const RuntimeArray&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; const Type* element_type() const { return element_type_; } RuntimeArray* AsRuntimeArray() override { return this; } const RuntimeArray* AsRuntimeArray() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + void ReplaceElementType(const Type* element_type); private: - Type* element_type_; + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; }; class Struct : public Type { public: - Struct(const std::vector& element_types); + Struct(const std::vector& element_types); Struct(const Struct&) = default; // Adds a decoration to the member at the given index. The first word is the // decoration enum, and the remaining words, if any, are its operands. void AddMemberDecoration(uint32_t index, std::vector&& decoration); - bool IsSame(const Type* that) const override; std::string str() const override; - const std::vector& element_types() const { return element_types_; } + const std::vector& element_types() const { + return element_types_; + } + std::vector& element_types() { return element_types_; } bool decoration_empty() const override { return decorations_.empty() && element_decorations_.empty(); } @@ -379,15 +423,18 @@ class Struct : public Type { Struct* AsStruct() override { return this; } const Struct* AsStruct() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + void ClearDecorations() override { decorations_.clear(); element_decorations_.clear(); } - std::vector element_types_; + std::vector element_types_; // We can attach decorations to struct members and that should not affect the // underlying element type. So we need an extra data structure here to keep // track of element type decorations. They must be stored in an ordered map @@ -401,7 +448,6 @@ class Opaque : public Type { Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {} Opaque(const Opaque&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; Opaque* AsOpaque() override { return this; } @@ -409,9 +455,12 @@ class Opaque : public Type { const std::string& name() const { return name_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + std::string name_; }; @@ -420,7 +469,6 @@ class Pointer : public Type { Pointer(const Type* pointee, SpvStorageClass sc); Pointer(const Pointer&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; const Type* pointee_type() const { return pointee_type_; } SpvStorageClass storage_class() const { return storage_class_; } @@ -428,32 +476,42 @@ class Pointer : public Type { Pointer* AsPointer() override { return this; } const Pointer* AsPointer() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + void SetPointeeType(const Type* type); private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + const Type* pointee_type_; SpvStorageClass storage_class_; }; class Function : public Type { public: - Function(Type* ret_type, const std::vector& params); + Function(Type* ret_type, const std::vector& params); Function(const Function&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; Function* AsFunction() override { return this; } const Function* AsFunction() const override { return this; } const Type* return_type() const { return return_type_; } - const std::vector& param_types() const { return param_types_; } + const std::vector& param_types() const { return param_types_; } + std::vector& param_types() { return param_types_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set*) const override; + + void SetReturnType(const Type* type); private: - Type* return_type_; - std::vector param_types_; + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* return_type_; + std::vector param_types_; }; class Pipe : public Type { @@ -462,7 +520,6 @@ class Pipe : public Type { : Type(kPipe), access_qualifier_(qualifier) {} Pipe(const Pipe&) = default; - bool IsSame(const Type* that) const override; std::string str() const override; Pipe* AsPipe() override { return this; } @@ -470,9 +527,12 @@ class Pipe : public Type { SpvAccessQualifier access_qualifier() const { return access_qualifier_; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + SpvAccessQualifier access_qualifier_; }; @@ -486,39 +546,44 @@ class ForwardPointer : public Type { ForwardPointer(const ForwardPointer&) = default; uint32_t target_id() const { return target_id_; } - void SetTargetPointer(Pointer* pointer) { pointer_ = pointer; } + void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; } SpvStorageClass storage_class() const { return storage_class_; } const Pointer* target_pointer() const { return pointer_; } - bool IsSame(const Type* that) const override; std::string str() const override; ForwardPointer* AsForwardPointer() override { return this; } const ForwardPointer* AsForwardPointer() const override { return this; } - void GetExtraHashWords(std::vector* words) const override; + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + uint32_t target_id_; SpvStorageClass storage_class_; - Pointer* pointer_; + const Pointer* pointer_; }; -#define DefineParameterlessType(type, name) \ - class type : public Type { \ - public: \ - type() : Type(k##type) {} \ - type(const type&) = default; \ - \ - bool IsSame(const Type* that) const override { \ - return that->As##type() && HasSameDecorations(that); \ - } \ - std::string str() const override { return #name; } \ - \ - type* As##type() override { return this; } \ - const type* As##type() const override { return this; } \ - \ - void GetExtraHashWords(std::vector*) const override {} \ +#define DefineParameterlessType(type, name) \ + class type : public Type { \ + public: \ + type() : Type(k##type) {} \ + type(const type&) = default; \ + \ + std::string str() const override { return #name; } \ + \ + type* As##type() override { return this; } \ + const type* As##type() const override { return this; } \ + \ + void GetExtraHashWords(std::vector*, \ + std::unordered_set*) const override {} \ + \ + private: \ + bool IsSameImpl(const Type* that, IsSameCache*) const override { \ + return that->As##type() && HasSameDecorations(that); \ + } \ } DefineParameterlessType(Void, void); DefineParameterlessType(Bool, bool); @@ -535,4 +600,4 @@ DefineParameterlessType(NamedBarrier, named_barrier); } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_TYPES_H_ +#endif // SOURCE_OPT_TYPES_H_ diff --git a/3rdparty/spirv-tools/source/opt/unify_const_pass.cpp b/3rdparty/spirv-tools/source/opt/unify_const_pass.cpp index 266757339..227fd61da 100644 --- a/3rdparty/spirv-tools/source/opt/unify_const_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/unify_const_pass.cpp @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unify_const_pass.h" +#include "source/opt/unify_const_pass.h" +#include #include #include +#include -#include "def_use_manager.h" -#include "ir_context.h" -#include "make_unique.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { @@ -39,7 +41,7 @@ class ResultIdTrie { // is found, creates a trie node with those keys, stores the instruction's // result id and returns that result id. If an existing result id is found, // returns the existing result id. - uint32_t LookupEquivalentResultFor(const ir::Instruction& inst) { + uint32_t LookupEquivalentResultFor(const Instruction& inst) { auto keys = GetLookUpKeys(inst); auto* node = root_.get(); for (uint32_t key : keys) { @@ -85,7 +87,7 @@ class ResultIdTrie { // Returns a vector of the opcode followed by the words in the raw SPIR-V // instruction encoding but without the result id. - std::vector GetLookUpKeys(const ir::Instruction& inst) { + std::vector GetLookUpKeys(const Instruction& inst) { std::vector keys; // Need to use the opcode, otherwise there might be a conflict with the // following case when 's binary value equals xx's id: @@ -103,12 +105,13 @@ class ResultIdTrie { }; } // anonymous namespace -Pass::Status UnifyConstantPass::Process(ir::IRContext* c) { - InitializeProcessing(c); +Pass::Status UnifyConstantPass::Process() { bool modified = false; ResultIdTrie defined_constants; - for( ir::Instruction* next_instruction, *inst = &*(context()->types_values_begin()); inst; inst = next_instruction) { + for (Instruction *next_instruction, + *inst = &*(context()->types_values_begin()); + inst; inst = next_instruction) { next_instruction = inst->NextNode(); // Do not handle the instruction when there are decorations upon the result diff --git a/3rdparty/spirv-tools/source/opt/unify_const_pass.h b/3rdparty/spirv-tools/source/opt/unify_const_pass.h index 648381765..f2b7897cc 100644 --- a/3rdparty/spirv-tools/source/opt/unify_const_pass.h +++ b/3rdparty/spirv-tools/source/opt/unify_const_pass.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_UNIFY_CONSTANT_PASS_H_ -#define LIBSPIRV_OPT_UNIFY_CONSTANT_PASS_H_ +#ifndef SOURCE_OPT_UNIFY_CONST_PASS_H_ +#define SOURCE_OPT_UNIFY_CONST_PASS_H_ -#include "ir_context.h" -#include "module.h" -#include "pass.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -26,10 +26,10 @@ namespace opt { class UnifyConstantPass : public Pass { public: const char* name() const override { return "unify-const"; } - Status Process(ir::IRContext*) override; + Status Process() override; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_UNIFY_CONSTANT_PASS_H_ +#endif // SOURCE_OPT_UNIFY_CONST_PASS_H_ diff --git a/3rdparty/spirv-tools/source/opt/value_number_table.cpp b/3rdparty/spirv-tools/source/opt/value_number_table.cpp index 7f5b7ce47..1bac63fab 100644 --- a/3rdparty/spirv-tools/source/opt/value_number_table.cpp +++ b/3rdparty/spirv-tools/source/opt/value_number_table.cpp @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "value_number_table.h" +#include "source/opt/value_number_table.h" #include -#include "cfg.h" +#include "source/opt/cfg.h" +#include "source/opt/ir_context.h" namespace spvtools { namespace opt { -uint32_t ValueNumberTable::GetValueNumber( - spvtools::ir::Instruction* inst) const { +uint32_t ValueNumberTable::GetValueNumber(Instruction* inst) const { assert(inst->result_id() != 0 && "inst must have a result id to get a value number."); @@ -34,7 +34,11 @@ uint32_t ValueNumberTable::GetValueNumber( return 0; } -uint32_t ValueNumberTable::AssignValueNumber(ir::Instruction* inst) { +uint32_t ValueNumberTable::GetValueNumber(uint32_t id) const { + return GetValueNumber(context()->get_def_use_mgr()->GetDef(id)); +} + +uint32_t ValueNumberTable::AssignValueNumber(Instruction* inst) { // If it already has a value return that. uint32_t value = GetValueNumber(inst); if (value != 0) { @@ -103,19 +107,19 @@ uint32_t ValueNumberTable::AssignValueNumber(ir::Instruction* inst) { // Replace all of the operands by their value number. The sign bit will be // set to distinguish between an id and a value number. - ir::Instruction value_ins(context(), inst->opcode(), inst->type_id(), - inst->result_id(), {}); + Instruction value_ins(context(), inst->opcode(), inst->type_id(), + inst->result_id(), {}); for (uint32_t o = 0; o < inst->NumInOperands(); ++o) { - const ir::Operand& op = inst->GetInOperand(o); + const Operand& op = inst->GetInOperand(o); if (spvIsIdType(op.type)) { uint32_t id_value = op.words[0]; auto use_id_to_val = id_to_value_.find(id_value); if (use_id_to_val != id_to_value_.end()) { id_value = (1 << 31) | use_id_to_val->second; } - value_ins.AddOperand(ir::Operand(op.type, {id_value})); + value_ins.AddOperand(Operand(op.type, {id_value})); } else { - value_ins.AddOperand(ir::Operand(op.type, op.words)); + value_ins.AddOperand(Operand(op.type, op.words)); } } @@ -163,11 +167,11 @@ void ValueNumberTable::BuildDominatorTreeValueNumberTable() { } } - for (ir::Function& func : *context()->module()) { + for (Function& func : *context()->module()) { // For best results we want to traverse the code in reverse post order. // This happens naturally because of the forward referencing rules. - for (ir::BasicBlock& block : func) { - for (ir::Instruction& inst : block) { + for (BasicBlock& block : func) { + for (Instruction& inst : block) { if (inst.result_id() != 0) { AssignValueNumber(&inst); } @@ -176,8 +180,8 @@ void ValueNumberTable::BuildDominatorTreeValueNumberTable() { } } -bool ComputeSameValue::operator()(const ir::Instruction& lhs, - const ir::Instruction& rhs) const { +bool ComputeSameValue::operator()(const Instruction& lhs, + const Instruction& rhs) const { if (lhs.result_id() == 0 || rhs.result_id() == 0) { return false; } @@ -204,8 +208,7 @@ bool ComputeSameValue::operator()(const ir::Instruction& lhs, lhs.result_id(), rhs.result_id()); } -std::size_t ValueTableHash::operator()( - const spvtools::ir::Instruction& inst) const { +std::size_t ValueTableHash::operator()(const Instruction& inst) const { // We hash the opcode and in-operands, not the result, because we want // instructions that are the same except for the result to hash to the // same value. diff --git a/3rdparty/spirv-tools/source/opt/value_number_table.h b/3rdparty/spirv-tools/source/opt/value_number_table.h index 8ad20df34..39129ffa3 100644 --- a/3rdparty/spirv-tools/source/opt/value_number_table.h +++ b/3rdparty/spirv-tools/source/opt/value_number_table.h @@ -12,28 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_VALUE_NUMBER_TABLE_H_ -#define LIBSPIRV_OPT_VALUE_NUMBER_TABLE_H_ +#ifndef SOURCE_OPT_VALUE_NUMBER_TABLE_H_ +#define SOURCE_OPT_VALUE_NUMBER_TABLE_H_ #include #include -#include "instruction.h" -#include "ir_context.h" + +#include "source/opt/instruction.h" namespace spvtools { namespace opt { +class IRContext; + // Returns true if the two instructions compute the same value. Used by the // value number table to compare two instructions. class ComputeSameValue { public: - bool operator()(const ir::Instruction& lhs, const ir::Instruction& rhs) const; + bool operator()(const Instruction& lhs, const Instruction& rhs) const; }; // The hash function used in the value number table. class ValueTableHash { public: - std::size_t operator()(const spvtools::ir::Instruction& inst) const; + std::size_t operator()(const Instruction& inst) const; }; // This class implements the value number analysis. It is using a hash-based @@ -49,20 +51,20 @@ class ValueTableHash { // the scope. class ValueNumberTable { public: - ValueNumberTable(ir::IRContext* ctx) : context_(ctx), next_value_number_(1) { + ValueNumberTable(IRContext* ctx) : context_(ctx), next_value_number_(1) { BuildDominatorTreeValueNumberTable(); } // Returns the value number of the value computed by |inst|. |inst| must have // a result id that will hold the computed value. If no value number has been // assigned to the result id, then the return value is 0. - uint32_t GetValueNumber(spvtools::ir::Instruction* inst) const; + uint32_t GetValueNumber(Instruction* inst) const; // Returns the value number of the value contain in |id|. Returns 0 if it // has not been assigned a value number. - inline uint32_t GetValueNumber(uint32_t id) const; + uint32_t GetValueNumber(uint32_t id) const; - ir::IRContext* context() const { return context_; } + IRContext* context() const { return context_; } private: // Assigns a value number to every result id in the module. @@ -74,21 +76,16 @@ class ValueNumberTable { // Assigns a new value number to the result of |inst| if it does not already // have one. Return the value number for |inst|. |inst| must have a result // id. - uint32_t AssignValueNumber(ir::Instruction* inst); + uint32_t AssignValueNumber(Instruction* inst); - std::unordered_map + std::unordered_map instruction_to_value_; std::unordered_map id_to_value_; - ir::IRContext* context_; + IRContext* context_; uint32_t next_value_number_; }; -uint32_t ValueNumberTable::GetValueNumber(uint32_t id) const { - return GetValueNumber(context()->get_def_use_mgr()->GetDef(id)); -} - } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_VALUE_NUMBER_TABLE_H_ +#endif // SOURCE_OPT_VALUE_NUMBER_TABLE_H_ diff --git a/3rdparty/spirv-tools/source/opt/vector_dce.cpp b/3rdparty/spirv-tools/source/opt/vector_dce.cpp new file mode 100644 index 000000000..911242e05 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/vector_dce.cpp @@ -0,0 +1,367 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/vector_dce.h" + +#include + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kExtractCompositeIdInIdx = 0; +const uint32_t kInsertObjectIdInIdx = 0; +const uint32_t kInsertCompositeIdInIdx = 1; + +} // namespace + +Pass::Status VectorDCE::Process() { + bool modified = false; + for (Function& function : *get_module()) { + modified |= VectorDCEFunction(&function); + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool VectorDCE::VectorDCEFunction(Function* function) { + LiveComponentMap live_components; + FindLiveComponents(function, &live_components); + return RewriteInstructions(function, live_components); +} + +void VectorDCE::FindLiveComponents(Function* function, + LiveComponentMap* live_components) { + std::vector work_list; + + // Prime the work list. We will assume that any instruction that does + // not result in a vector is live. + // + // Extending to structures and matrices is not as straight forward because of + // the nesting. We cannot simply us a bit vector to keep track of which + // components are live because of arbitrary nesting of structs. + function->ForEachInst( + [&work_list, this, live_components](Instruction* current_inst) { + if (!HasVectorOrScalarResult(current_inst) || + !context()->IsCombinatorInstruction(current_inst)) { + MarkUsesAsLive(current_inst, all_components_live_, live_components, + &work_list); + } + }); + + // Process the work list propagating liveness. + for (uint32_t i = 0; i < work_list.size(); i++) { + WorkListItem current_item = work_list[i]; + Instruction* current_inst = current_item.instruction; + + switch (current_inst->opcode()) { + case SpvOpCompositeExtract: + MarkExtractUseAsLive(current_inst, live_components, &work_list); + break; + case SpvOpCompositeInsert: + MarkInsertUsesAsLive(current_item, live_components, &work_list); + break; + case SpvOpVectorShuffle: + MarkVectorShuffleUsesAsLive(current_item, live_components, &work_list); + break; + case SpvOpCompositeConstruct: + MarkCompositeContructUsesAsLive(current_item, live_components, + &work_list); + break; + default: + if (current_inst->IsScalarizable()) { + MarkUsesAsLive(current_inst, current_item.components, live_components, + &work_list); + } else { + MarkUsesAsLive(current_inst, all_components_live_, live_components, + &work_list); + } + break; + } + } +} + +void VectorDCE::MarkExtractUseAsLive(const Instruction* current_inst, + LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + uint32_t operand_id = + current_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + Instruction* operand_inst = def_use_mgr->GetDef(operand_id); + + if (HasVectorOrScalarResult(operand_inst)) { + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components.Set(current_inst->GetSingleWordInOperand(1)); + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } +} + +void VectorDCE::MarkInsertUsesAsLive( + const VectorDCE::WorkListItem& current_item, + LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + uint32_t insert_position = + current_item.instruction->GetSingleWordInOperand(2); + + // Add the elements of the composite object that are used. + uint32_t operand_id = + current_item.instruction->GetSingleWordInOperand(kInsertCompositeIdInIdx); + Instruction* operand_inst = def_use_mgr->GetDef(operand_id); + + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components = current_item.components; + new_item.components.Clear(insert_position); + + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + + // Add the element being inserted if it is used. + if (current_item.components.Get(insert_position)) { + uint32_t obj_operand_id = + current_item.instruction->GetSingleWordInOperand(kInsertObjectIdInIdx); + Instruction* obj_operand_inst = def_use_mgr->GetDef(obj_operand_id); + WorkListItem new_item_for_obj; + new_item_for_obj.instruction = obj_operand_inst; + new_item_for_obj.components.Set(0); + AddItemToWorkListIfNeeded(new_item_for_obj, live_components, work_list); + } +} + +void VectorDCE::MarkVectorShuffleUsesAsLive( + const WorkListItem& current_item, + VectorDCE::LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + WorkListItem first_operand; + first_operand.instruction = + def_use_mgr->GetDef(current_item.instruction->GetSingleWordInOperand(0)); + WorkListItem second_operand; + second_operand.instruction = + def_use_mgr->GetDef(current_item.instruction->GetSingleWordInOperand(1)); + + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Vector* first_type = + type_mgr->GetType(first_operand.instruction->type_id())->AsVector(); + uint32_t size_of_first_operand = first_type->element_count(); + + for (uint32_t in_op = 2; in_op < current_item.instruction->NumInOperands(); + ++in_op) { + uint32_t index = current_item.instruction->GetSingleWordInOperand(in_op); + if (current_item.components.Get(in_op - 2)) { + if (index < size_of_first_operand) { + first_operand.components.Set(index); + } else { + second_operand.components.Set(index - size_of_first_operand); + } + } + } + + AddItemToWorkListIfNeeded(first_operand, live_components, work_list); + AddItemToWorkListIfNeeded(second_operand, live_components, work_list); +} + +void VectorDCE::MarkCompositeContructUsesAsLive( + VectorDCE::WorkListItem work_item, + VectorDCE::LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + + uint32_t current_component = 0; + Instruction* current_inst = work_item.instruction; + uint32_t num_in_operands = current_inst->NumInOperands(); + for (uint32_t i = 0; i < num_in_operands; ++i) { + uint32_t id = current_inst->GetSingleWordInOperand(i); + Instruction* op_inst = def_use_mgr->GetDef(id); + + if (HasScalarResult(op_inst)) { + WorkListItem new_work_item; + new_work_item.instruction = op_inst; + if (work_item.components.Get(current_component)) { + new_work_item.components.Set(0); + } + AddItemToWorkListIfNeeded(new_work_item, live_components, work_list); + current_component++; + } else { + assert(HasVectorResult(op_inst)); + WorkListItem new_work_item; + new_work_item.instruction = op_inst; + uint32_t op_vector_size = + type_mgr->GetType(op_inst->type_id())->AsVector()->element_count(); + + for (uint32_t op_vector_idx = 0; op_vector_idx < op_vector_size; + op_vector_idx++, current_component++) { + if (work_item.components.Get(current_component)) { + new_work_item.components.Set(op_vector_idx); + } + } + AddItemToWorkListIfNeeded(new_work_item, live_components, work_list); + } + } +} + +void VectorDCE::MarkUsesAsLive( + Instruction* current_inst, const utils::BitVector& live_elements, + LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + current_inst->ForEachInId([&work_list, &live_elements, this, live_components, + def_use_mgr](uint32_t* operand_id) { + Instruction* operand_inst = def_use_mgr->GetDef(*operand_id); + + if (HasVectorResult(operand_inst)) { + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components = live_elements; + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } else if (HasScalarResult(operand_inst)) { + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components.Set(0); + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } + }); +} + +bool VectorDCE::HasVectorOrScalarResult(const Instruction* inst) const { + return HasScalarResult(inst) || HasVectorResult(inst); +} + +bool VectorDCE::HasVectorResult(const Instruction* inst) const { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + if (inst->type_id() == 0) { + return false; + } + + const analysis::Type* current_type = type_mgr->GetType(inst->type_id()); + switch (current_type->kind()) { + case analysis::Type::kVector: + return true; + default: + return false; + } +} + +bool VectorDCE::HasScalarResult(const Instruction* inst) const { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + if (inst->type_id() == 0) { + return false; + } + + const analysis::Type* current_type = type_mgr->GetType(inst->type_id()); + switch (current_type->kind()) { + case analysis::Type::kBool: + case analysis::Type::kInteger: + case analysis::Type::kFloat: + return true; + default: + return false; + } +} + +bool VectorDCE::RewriteInstructions( + Function* function, const VectorDCE::LiveComponentMap& live_components) { + bool modified = false; + function->ForEachInst( + [&modified, this, live_components](Instruction* current_inst) { + if (!context()->IsCombinatorInstruction(current_inst)) { + return; + } + + auto live_component = live_components.find(current_inst->result_id()); + if (live_component == live_components.end()) { + // If this instruction is not in live_components then it does not + // produce a vector, or it is never referenced and ADCE will remove + // it. No point in trying to differentiate. + return; + } + + // If no element in the current instruction is used replace it with an + // OpUndef. + if (live_component->second.Empty()) { + modified = true; + uint32_t undef_id = this->Type2Undef(current_inst->type_id()); + context()->KillNamesAndDecorates(current_inst); + context()->ReplaceAllUsesWith(current_inst->result_id(), undef_id); + context()->KillInst(current_inst); + return; + } + + switch (current_inst->opcode()) { + case SpvOpCompositeInsert: + modified |= + RewriteInsertInstruction(current_inst, live_component->second); + break; + case SpvOpCompositeConstruct: + // TODO: The members that are not live can be replaced by an undef + // or constant. This will remove uses of those values, and possibly + // create opportunities for ADCE. + break; + default: + // Do nothing. + break; + } + }); + return modified; +} + +bool VectorDCE::RewriteInsertInstruction( + Instruction* current_inst, const utils::BitVector& live_components) { + // If the value being inserted is not live, then we can skip the insert. + bool modified = false; + uint32_t insert_index = current_inst->GetSingleWordInOperand(2); + if (!live_components.Get(insert_index)) { + modified = true; + context()->KillNamesAndDecorates(current_inst->result_id()); + uint32_t composite_id = + current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx); + context()->ReplaceAllUsesWith(current_inst->result_id(), composite_id); + } + + // If the values already in the composite are not used, then replace it with + // an undef. + utils::BitVector temp = live_components; + temp.Clear(insert_index); + if (temp.Empty()) { + context()->ForgetUses(current_inst); + uint32_t undef_id = Type2Undef(current_inst->type_id()); + current_inst->SetInOperand(kInsertCompositeIdInIdx, {undef_id}); + context()->AnalyzeUses(current_inst); + } + + return modified; +} + +void VectorDCE::AddItemToWorkListIfNeeded( + WorkListItem work_item, VectorDCE::LiveComponentMap* live_components, + std::vector* work_list) { + Instruction* current_inst = work_item.instruction; + auto it = live_components->find(current_inst->result_id()); + if (it == live_components->end()) { + live_components->emplace( + std::make_pair(current_inst->result_id(), work_item.components)); + work_list->emplace_back(work_item); + } else { + if (it->second.Or(work_item.components)) { + work_list->emplace_back(work_item); + } + } +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/vector_dce.h b/3rdparty/spirv-tools/source/opt/vector_dce.h new file mode 100644 index 000000000..48886998d --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/vector_dce.h @@ -0,0 +1,149 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_VECTOR_DCE_H_ +#define SOURCE_OPT_VECTOR_DCE_H_ + +#include +#include + +#include "source/opt/mem_pass.h" +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace opt { + +class VectorDCE : public MemPass { + private: + using LiveComponentMap = std::unordered_map; + + // According to the SPEC the maximum size for a vector is 16. See the data + // rules in the universal validation rules (section 2.16.1). + enum { kMaxVectorSize = 16 }; + + struct WorkListItem { + WorkListItem() : instruction(nullptr), components(kMaxVectorSize) {} + + Instruction* instruction; + utils::BitVector components; + }; + + public: + VectorDCE() : all_components_live_(kMaxVectorSize) { + for (uint32_t i = 0; i < kMaxVectorSize; i++) { + all_components_live_.Set(i); + } + } + + const char* name() const override { return "vector-dce"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisCFG | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisDecorations | + IRContext::kAnalysisDominatorAnalysis | IRContext::kAnalysisNameMap; + } + + private: + // Runs the vector dce pass on |function|. Returns true if |function| was + // modified. + bool VectorDCEFunction(Function* function); + + // Identifies the live components of the vectors that are results of + // instructions in |function|. The results are stored in |live_components|. + void FindLiveComponents(Function* function, + LiveComponentMap* live_components); + + // Rewrites instructions in |function| that are dead or partially dead. If an + // instruction does not have an entry in |live_components|, then it is not + // changed. Returns true if |function| was modified. + bool RewriteInstructions(Function* function, + const LiveComponentMap& live_components); + + // Rewrites the OpCompositeInsert instruction |current_inst| to avoid + // unnecessary computes given that the only components of the result that are + // live are |live_components|. + // + // If the value being inserted is not live, then the result of |current_inst| + // is replaced by the composite input to |current_inst|. + // + // If the composite input to |current_inst| is not live, then it is replaced + // by and OpUndef in |current_inst|. + bool RewriteInsertInstruction(Instruction* current_inst, + const utils::BitVector& live_components); + + // Returns true if the result of |inst| is a vector or a scalar. + bool HasVectorOrScalarResult(const Instruction* inst) const; + + // Returns true if the result of |inst| is a scalar. + bool HasVectorResult(const Instruction* inst) const; + + // Returns true if the result of |inst| is a vector. + bool HasScalarResult(const Instruction* inst) const; + + // Adds |work_item| to |work_list| if it is not already live according to + // |live_components|. |live_components| is updated to indicate that + // |work_item| is now live. + void AddItemToWorkListIfNeeded(WorkListItem work_item, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the components |live_elements| of the uses in |current_inst| as live + // according to |live_components|. If they were not live before, then they are + // added to |work_list|. + void MarkUsesAsLive(Instruction* current_inst, + const utils::BitVector& live_elements, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpVectorShuffle instruction in |current_item| as live + // based on the live components in |current_item|. If anything becomes live + // they are added to |work_list| and |live_components| is updated + // accordingly. + void MarkVectorShuffleUsesAsLive(const WorkListItem& current_item, + VectorDCE::LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpCompositeInsert instruction in |current_item| as + // live based on the live components in |current_item|. If anything becomes + // live they are added to |work_list| and |live_components| is updated + // accordingly. + void MarkInsertUsesAsLive(const WorkListItem& current_item, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpCompositeExtract instruction |current_inst| as + // live. If anything becomes live they are added to |work_list| and + // |live_components| is updated accordingly. + void MarkExtractUseAsLive(const Instruction* current_inst, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpCompositeConstruct instruction |current_inst| as + // live. If anything becomes live they are added to |work_list| and + // |live_components| is updated accordingly. + void MarkCompositeContructUsesAsLive(WorkListItem work_item, + LiveComponentMap* live_components, + std::vector* work_list); + + // A BitVector that can always be used to say that all components of a vector + // are live. + utils::BitVector all_components_live_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_VECTOR_DCE_H_ diff --git a/3rdparty/spirv-tools/source/opt/workaround1209.cpp b/3rdparty/spirv-tools/source/opt/workaround1209.cpp index ae05848ea..d6e9d2cf7 100644 --- a/3rdparty/spirv-tools/source/opt/workaround1209.cpp +++ b/3rdparty/spirv-tools/source/opt/workaround1209.cpp @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "workaround1209.h" +#include "source/opt/workaround1209.h" #include +#include #include +#include namespace spvtools { namespace opt { -Pass::Status Workaround1209::Process(ir::IRContext* c) { - InitializeProcessing(c); +Pass::Status Workaround1209::Process() { bool modified = false; modified = RemoveOpUnreachableInLoops(); return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); @@ -30,14 +31,14 @@ Pass::Status Workaround1209::Process(ir::IRContext* c) { bool Workaround1209::RemoveOpUnreachableInLoops() { bool modified = false; for (auto& func : *get_module()) { - std::list structured_order; + std::list structured_order; cfg()->ComputeStructuredOrder(&func, &*func.begin(), &structured_order); // Keep track of the loop merges. The top of the stack will always be the // loop merge for the loop that immediately contains the basic block being // processed. std::stack loop_merges; - for (ir::BasicBlock* bb : structured_order) { + for (BasicBlock* bb : structured_order) { if (!loop_merges.empty() && bb->id() == loop_merges.top()) { loop_merges.pop(); } @@ -47,10 +48,10 @@ bool Workaround1209::RemoveOpUnreachableInLoops() { // We found an OpUnreachable inside a loop. // Replace it with an unconditional branch to the loop merge. context()->KillInst(&*bb->tail()); - std::unique_ptr new_branch( - new ir::Instruction(context(), SpvOpBranch, 0, 0, - {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, - {loop_merges.top()}}})); + std::unique_ptr new_branch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {loop_merges.top()}}})); context()->AnalyzeDefUse(&*new_branch); bb->AddInstruction(std::move(new_branch)); modified = true; diff --git a/3rdparty/spirv-tools/source/opt/workaround1209.h b/3rdparty/spirv-tools/source/opt/workaround1209.h index 2265ac3f9..9a1f88d93 100644 --- a/3rdparty/spirv-tools/source/opt/workaround1209.h +++ b/3rdparty/spirv-tools/source/opt/workaround1209.h @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_WORKAROUND1209_H_ -#define LIBSPIRV_OPT_WORKAROUND1209_H_ +#ifndef SOURCE_OPT_WORKAROUND1209_H_ +#define SOURCE_OPT_WORKAROUND1209_H_ -#include "pass.h" +#include "source/opt/pass.h" namespace spvtools { namespace opt { @@ -24,7 +24,7 @@ namespace opt { class Workaround1209 : public Pass { public: const char* name() const override { return "workaround-1209"; } - Status Process(ir::IRContext*) override; + Status Process() override; private: // There is at least one driver where an OpUnreachable found in a loop is not @@ -38,4 +38,4 @@ class Workaround1209 : public Pass { } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_WORKAROUND1209_H_ +#endif // SOURCE_OPT_WORKAROUND1209_H_ diff --git a/3rdparty/spirv-tools/source/parsed_operand.cpp b/3rdparty/spirv-tools/source/parsed_operand.cpp index 6f3ffe8d5..7ad369cdb 100644 --- a/3rdparty/spirv-tools/source/parsed_operand.cpp +++ b/3rdparty/spirv-tools/source/parsed_operand.cpp @@ -14,19 +14,21 @@ // This file contains utility functions for spv_parsed_operand_t. -#include "parsed_operand.h" +#include "source/parsed_operand.h" #include -#include "util/hex_float.h" +#include "source/util/hex_float.h" -namespace libspirv { +namespace spvtools { void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst, const spv_parsed_operand_t& operand) { - assert(operand.type == SPV_OPERAND_TYPE_LITERAL_INTEGER || - operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER); - assert(1 <= operand.num_words); - assert(operand.num_words <= 2); + if (operand.type != SPV_OPERAND_TYPE_LITERAL_INTEGER && + operand.type != SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) + return; + if (operand.num_words < 1) return; + // TODO(dneto): Support more than 64-bits at a time. + if (operand.num_words > 2) return; const uint32_t word = inst.words[operand.offset]; if (operand.num_words == 1) { @@ -39,15 +41,15 @@ void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst, break; case SPV_NUMBER_FLOATING: if (operand.number_bit_width == 16) { - *out << spvutils::FloatProxy( + *out << spvtools::utils::FloatProxy( uint16_t(word & 0xFFFF)); } else { // Assume 32-bit floats. - *out << spvutils::FloatProxy(word); + *out << spvtools::utils::FloatProxy(word); } break; default: - assert(false && "Unreachable"); + break; } } else if (operand.num_words == 2) { // Multi-word numbers are presented with lower order words first. @@ -62,14 +64,11 @@ void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst, break; case SPV_NUMBER_FLOATING: // Assume only 64-bit floats. - *out << spvutils::FloatProxy(bits); + *out << spvtools::utils::FloatProxy(bits); break; default: - assert(false && "Unreachable"); + break; } - } else { - // TODO(dneto): Support more than 64-bits at a time. - assert(false && "Unhandled"); } } -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/parsed_operand.h b/3rdparty/spirv-tools/source/parsed_operand.h index 8c2fd85ce..bab861107 100644 --- a/3rdparty/spirv-tools/source/parsed_operand.h +++ b/3rdparty/spirv-tools/source/parsed_operand.h @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_PARSED_OPERAND_H_ -#define LIBSPIRV_PARSED_OPERAND_H_ +#ifndef SOURCE_PARSED_OPERAND_H_ +#define SOURCE_PARSED_OPERAND_H_ #include + #include "spirv-tools/libspirv.h" -namespace libspirv { +namespace spvtools { // Emits the numeric literal representation of the given instruction operand // to the stream. The operand must be of numeric type. If integral it may @@ -27,6 +28,6 @@ namespace libspirv { void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst, const spv_parsed_operand_t& operand); -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_BINARY_H_ +#endif // SOURCE_PARSED_OPERAND_H_ diff --git a/3rdparty/spirv-tools/source/print.cpp b/3rdparty/spirv-tools/source/print.cpp index 70d8f5965..f75e2d457 100644 --- a/3rdparty/spirv-tools/source/print.cpp +++ b/3rdparty/spirv-tools/source/print.cpp @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "print.h" +#include "source/print.h" #if defined(SPIRV_ANDROID) || defined(SPIRV_LINUX) || defined(SPIRV_MAC) || \ defined(SPIRV_FREEBSD) -namespace libspirv { +namespace spvtools { clr::reset::operator const char*() { return "\x1b[0m"; } @@ -30,11 +30,11 @@ clr::yellow::operator const char*() { return "\x1b[33m"; } clr::blue::operator const char*() { return "\x1b[34m"; } -} // namespace libspirv +} // namespace spvtools #elif defined(SPIRV_WINDOWS) #include -namespace libspirv { +namespace spvtools { static void SetConsoleForegroundColorPrimary(HANDLE hConsole, WORD color) { // Get screen buffer information from console handle @@ -105,9 +105,9 @@ clr::blue::operator const char*() { return "\x1b[94m"; } -} // namespace libspirv +} // namespace spvtools #else -namespace libspirv { +namespace spvtools { clr::reset::operator const char*() { return ""; } @@ -121,5 +121,5 @@ clr::yellow::operator const char*() { return ""; } clr::blue::operator const char*() { return ""; } -} // namespace libspirv +} // namespace spvtools #endif diff --git a/3rdparty/spirv-tools/source/print.h b/3rdparty/spirv-tools/source/print.h index 76d7c40fc..f31ba38e7 100644 --- a/3rdparty/spirv-tools/source/print.h +++ b/3rdparty/spirv-tools/source/print.h @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_PRINT_H_ -#define LIBSPIRV_PRINT_H_ +#ifndef SOURCE_PRINT_H_ +#define SOURCE_PRINT_H_ #include #include -namespace libspirv { +namespace spvtools { // Wrapper for out stream selection. class out_stream { @@ -70,6 +70,6 @@ struct blue { }; } // namespace clr -} // namespace libspirv +} // namespace spvtools -#endif // LIBSPIRV_PRINT_H_ +#endif // SOURCE_PRINT_H_ diff --git a/3rdparty/spirv-tools/source/spirv_constant.h b/3rdparty/spirv-tools/source/spirv_constant.h index 8eb6572bc..39771ccb2 100644 --- a/3rdparty/spirv-tools/source/spirv_constant.h +++ b/3rdparty/spirv-tools/source/spirv_constant.h @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_SPIRV_CONSTANT_H_ -#define LIBSPIRV_SPIRV_CONSTANT_H_ +#ifndef SOURCE_SPIRV_CONSTANT_H_ +#define SOURCE_SPIRV_CONSTANT_H_ -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" #include "spirv-tools/libspirv.h" // Version number macros. @@ -97,4 +97,4 @@ typedef enum spv_generator_t { // Returns the misc part of the generator word. #define SPV_GENERATOR_MISC_PART(WORD) (uint32_t(WORD) & 0xFFFF) -#endif // LIBSPIRV_SPIRV_CONSTANT_H_ +#endif // SOURCE_SPIRV_CONSTANT_H_ diff --git a/3rdparty/spirv-tools/source/spirv_definition.h b/3rdparty/spirv-tools/source/spirv_definition.h index 9e22108fa..63a4ef0db 100644 --- a/3rdparty/spirv-tools/source/spirv_definition.h +++ b/3rdparty/spirv-tools/source/spirv_definition.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_SPIRV_DEFINITION_H_ -#define LIBSPIRV_SPIRV_DEFINITION_H_ +#ifndef SOURCE_SPIRV_DEFINITION_H_ +#define SOURCE_SPIRV_DEFINITION_H_ #include -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" #define spvIsInBitfield(value, bitfield) ((value) == ((value)&bitfield)) @@ -30,4 +30,4 @@ typedef struct spv_header_t { const uint32_t* instructions; // NOTE: Unfixed pointer to instruciton stream } spv_header_t; -#endif // LIBSPIRV_SPIRV_DEFINITION_H_ +#endif // SOURCE_SPIRV_DEFINITION_H_ diff --git a/3rdparty/spirv-tools/source/spirv_endian.cpp b/3rdparty/spirv-tools/source/spirv_endian.cpp index 56eaac855..1d7709178 100644 --- a/3rdparty/spirv-tools/source/spirv_endian.cpp +++ b/3rdparty/spirv-tools/source/spirv_endian.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "spirv_endian.h" +#include "source/spirv_endian.h" #include diff --git a/3rdparty/spirv-tools/source/spirv_endian.h b/3rdparty/spirv-tools/source/spirv_endian.h index c64b33819..c2540bec9 100644 --- a/3rdparty/spirv-tools/source/spirv_endian.h +++ b/3rdparty/spirv-tools/source/spirv_endian.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_SPIRV_ENDIAN_H_ -#define LIBSPIRV_SPIRV_ENDIAN_H_ +#ifndef SOURCE_SPIRV_ENDIAN_H_ +#define SOURCE_SPIRV_ENDIAN_H_ #include "spirv-tools/libspirv.h" @@ -34,4 +34,4 @@ spv_result_t spvBinaryEndianness(const spv_const_binary binary, // Returns true if the given endianness matches the host's native endiannes. bool spvIsHostEndian(spv_endianness_t endian); -#endif // LIBSPIRV_SPIRV_ENDIAN_H_ +#endif // SOURCE_SPIRV_ENDIAN_H_ diff --git a/3rdparty/spirv-tools/source/spirv_stats.cpp b/3rdparty/spirv-tools/source/spirv_stats.cpp deleted file mode 100644 index ff4b3c67a..000000000 --- a/3rdparty/spirv-tools/source/spirv_stats.cpp +++ /dev/null @@ -1,325 +0,0 @@ -// Copyright (c) 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "spirv_stats.h" - -#include - -#include -#include -#include -#include - -#include "binary.h" -#include "diagnostic.h" -#include "enum_string_mapping.h" -#include "extensions.h" -#include "id_descriptor.h" -#include "instruction.h" -#include "opcode.h" -#include "operand.h" -#include "spirv-tools/libspirv.h" -#include "spirv_endian.h" -#include "spirv_validator_options.h" -#include "val/instruction.h" -#include "val/validation_state.h" -#include "validate.h" - -using libspirv::IdDescriptorCollection; -using libspirv::Instruction; -using libspirv::SpirvStats; -using libspirv::ValidationState_t; - -namespace { - -// Helper class for stats aggregation. Receives as in/out parameter. -// Constructs ValidationState and updates it by running validator for each -// instruction. -class StatsAggregator { - public: - StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context) { - stats_ = in_out_stats; - vstate_.reset(new ValidationState_t(context, &validator_options_)); - } - - // Collects header statistics and sets correct id_bound. - spv_result_t ProcessHeader(spv_endianness_t /* endian */, - uint32_t /* magic */, uint32_t version, - uint32_t generator, uint32_t id_bound, - uint32_t /* schema */) { - vstate_->setIdBound(id_bound); - ++stats_->version_hist[version]; - ++stats_->generator_hist[generator]; - return SPV_SUCCESS; - } - - // Runs validator to validate the instruction and update vstate_, - // then procession the instruction to collect stats. - spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) { - const spv_result_t validation_result = - spvtools::ValidateInstructionAndUpdateValidationState(vstate_.get(), - inst); - if (validation_result != SPV_SUCCESS) return validation_result; - - ProcessOpcode(); - ProcessCapability(); - ProcessExtension(); - ProcessConstant(); - ProcessEnums(); - ProcessLiteralStrings(); - ProcessNonIdWords(); - ProcessIdDescriptors(); - - return SPV_SUCCESS; - } - - // Collects statistics of descriptors generated by IdDescriptorCollection. - void ProcessIdDescriptors() { - const Instruction& inst = GetCurrentInstruction(); - const uint32_t new_descriptor = - id_descriptors_.ProcessInstruction(inst.c_inst()); - - if (new_descriptor) { - std::stringstream ss; - ss << spvOpcodeString(inst.opcode()); - for (size_t i = 1; i < inst.words().size(); ++i) { - ss << " " << inst.word(i); - } - stats_->id_descriptor_labels.emplace(new_descriptor, ss.str()); - } - - uint32_t index = 0; - for (const auto& operand : inst.operands()) { - if (spvIsIdType(operand.type)) { - const uint32_t descriptor = - id_descriptors_.GetDescriptor(inst.word(operand.offset)); - if (descriptor) { - ++stats_->id_descriptor_hist[descriptor]; - ++stats_ - ->operand_slot_id_descriptor_hist[std::pair( - inst.opcode(), index)][descriptor]; - } - } - ++index; - } - } - - // Collects statistics of enum words for operands of specific types. - void ProcessEnums() { - const Instruction& inst = GetCurrentInstruction(); - for (const auto& operand : inst.operands()) { - switch (operand.type) { - case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: - case SPV_OPERAND_TYPE_EXECUTION_MODEL: - case SPV_OPERAND_TYPE_ADDRESSING_MODEL: - case SPV_OPERAND_TYPE_MEMORY_MODEL: - case SPV_OPERAND_TYPE_EXECUTION_MODE: - case SPV_OPERAND_TYPE_STORAGE_CLASS: - case SPV_OPERAND_TYPE_DIMENSIONALITY: - case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: - case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: - case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: - case SPV_OPERAND_TYPE_IMAGE_CHANNEL_ORDER: - case SPV_OPERAND_TYPE_IMAGE_CHANNEL_DATA_TYPE: - case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: - case SPV_OPERAND_TYPE_LINKAGE_TYPE: - case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: - case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: - case SPV_OPERAND_TYPE_DECORATION: - case SPV_OPERAND_TYPE_BUILT_IN: - case SPV_OPERAND_TYPE_GROUP_OPERATION: - case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: - case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: - case SPV_OPERAND_TYPE_CAPABILITY: { - ++stats_->enum_hist[operand.type][inst.word(operand.offset)]; - break; - } - default: - break; - } - } - } - - // Collects statistics of literal strings used by opcodes. - void ProcessLiteralStrings() { - const Instruction& inst = GetCurrentInstruction(); - for (const auto& operand : inst.operands()) { - if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) { - const std::string str = - reinterpret_cast(&inst.words()[operand.offset]); - ++stats_->literal_strings_hist[inst.opcode()][str]; - } - } - } - - // Collects statistics of all single word non-id operand slots. - void ProcessNonIdWords() { - const Instruction& inst = GetCurrentInstruction(); - uint32_t index = 0; - for (const auto& operand : inst.operands()) { - if (operand.num_words == 1 && !spvIsIdType(operand.type)) { - ++stats_->operand_slot_non_id_words_hist[std::pair( - inst.opcode(), index)][inst.word(operand.offset)]; - } - ++index; - } - } - - // Collects OpCapability statistics. - void ProcessCapability() { - const Instruction& inst = GetCurrentInstruction(); - if (inst.opcode() != SpvOpCapability) return; - const uint32_t capability = inst.word(inst.operands()[0].offset); - ++stats_->capability_hist[capability]; - } - - // Collects OpExtension statistics. - void ProcessExtension() { - const Instruction& inst = GetCurrentInstruction(); - if (inst.opcode() != SpvOpExtension) return; - const std::string extension = libspirv::GetExtensionString(&inst.c_inst()); - ++stats_->extension_hist[extension]; - } - - // Collects OpCode statistics. - void ProcessOpcode() { - auto inst_it = vstate_->ordered_instructions().rbegin(); - const SpvOp opcode = inst_it->opcode(); - ++stats_->opcode_hist[opcode]; - - const uint32_t opcode_and_num_operands = - (uint32_t(inst_it->operands().size()) << 16) | uint32_t(opcode); - ++stats_->opcode_and_num_operands_hist[opcode_and_num_operands]; - - ++inst_it; - - if (inst_it != vstate_->ordered_instructions().rend()) { - const SpvOp prev_opcode = inst_it->opcode(); - ++stats_->opcode_and_num_operands_markov_hist[prev_opcode] - [opcode_and_num_operands]; - } - - auto step_it = stats_->opcode_markov_hist.begin(); - for (; inst_it != vstate_->ordered_instructions().rend() && - step_it != stats_->opcode_markov_hist.end(); - ++inst_it, ++step_it) { - auto& hist = (*step_it)[inst_it->opcode()]; - ++hist[opcode]; - } - } - - // Collects OpConstant statistics. - void ProcessConstant() { - const Instruction& inst = GetCurrentInstruction(); - if (inst.opcode() != SpvOpConstant) return; - const uint32_t type_id = inst.GetOperandAs(0); - const auto type_decl_it = vstate_->all_definitions().find(type_id); - assert(type_decl_it != vstate_->all_definitions().end()); - const Instruction& type_decl_inst = *type_decl_it->second; - const SpvOp type_op = type_decl_inst.opcode(); - if (type_op == SpvOpTypeInt) { - const uint32_t bit_width = type_decl_inst.GetOperandAs(1); - const uint32_t is_signed = type_decl_inst.GetOperandAs(2); - assert(is_signed == 0 || is_signed == 1); - if (bit_width == 16) { - if (is_signed) - ++stats_->s16_constant_hist[inst.GetOperandAs(2)]; - else - ++stats_->u16_constant_hist[inst.GetOperandAs(2)]; - } else if (bit_width == 32) { - if (is_signed) - ++stats_->s32_constant_hist[inst.GetOperandAs(2)]; - else - ++stats_->u32_constant_hist[inst.GetOperandAs(2)]; - } else if (bit_width == 64) { - if (is_signed) - ++stats_->s64_constant_hist[inst.GetOperandAs(2)]; - else - ++stats_->u64_constant_hist[inst.GetOperandAs(2)]; - } else { - assert(false && "TypeInt bit width is not 16, 32 or 64"); - } - } else if (type_op == SpvOpTypeFloat) { - const uint32_t bit_width = type_decl_inst.GetOperandAs(1); - if (bit_width == 32) { - ++stats_->f32_constant_hist[inst.GetOperandAs(2)]; - } else if (bit_width == 64) { - ++stats_->f64_constant_hist[inst.GetOperandAs(2)]; - } else { - assert(bit_width == 16); - } - } - } - - SpirvStats* stats() { return stats_; } - - private: - // Returns the current instruction (the one last processed by the validator). - const Instruction& GetCurrentInstruction() const { - return vstate_->ordered_instructions().back(); - } - - SpirvStats* stats_; - spv_validator_options_t validator_options_; - std::unique_ptr vstate_; - IdDescriptorCollection id_descriptors_; -}; - -spv_result_t ProcessHeader(void* user_data, spv_endianness_t endian, - uint32_t magic, uint32_t version, uint32_t generator, - uint32_t id_bound, uint32_t schema) { - StatsAggregator* stats_aggregator = - reinterpret_cast(user_data); - return stats_aggregator->ProcessHeader(endian, magic, version, generator, - id_bound, schema); -} - -spv_result_t ProcessInstruction(void* user_data, - const spv_parsed_instruction_t* inst) { - StatsAggregator* stats_aggregator = - reinterpret_cast(user_data); - return stats_aggregator->ProcessInstruction(inst); -} - -} // namespace - -namespace libspirv { - -spv_result_t AggregateStats(const spv_context_t& context, const uint32_t* words, - const size_t num_words, spv_diagnostic* pDiagnostic, - SpirvStats* stats) { - spv_const_binary_t binary = {words, num_words}; - - spv_endianness_t endian; - spv_position_t position = {}; - if (spvBinaryEndianness(&binary, &endian)) { - return libspirv::DiagnosticStream(position, context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V magic number."; - } - - spv_header_t header; - if (spvBinaryHeaderGet(&binary, endian, &header)) { - return libspirv::DiagnosticStream(position, context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V header."; - } - - StatsAggregator stats_aggregator(stats, &context); - - return spvBinaryParse(&context, &stats_aggregator, words, num_words, - ProcessHeader, ProcessInstruction, pDiagnostic); -} - -} // namespace libspirv diff --git a/3rdparty/spirv-tools/source/spirv_target_env.cpp b/3rdparty/spirv-tools/source/spirv_target_env.cpp index cc99228e2..7a11630c7 100644 --- a/3rdparty/spirv-tools/source/spirv_target_env.cpp +++ b/3rdparty/spirv-tools/source/spirv_target_env.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "source/spirv_target_env.h" + #include #include +#include "source/spirv_constant.h" #include "spirv-tools/libspirv.h" -#include "spirv_constant.h" const char* spvTargetEnvDescription(spv_target_env env) { switch (env) { @@ -58,6 +60,8 @@ const char* spvTargetEnvDescription(spv_target_env env) { return "SPIR-V 1.3"; case SPV_ENV_VULKAN_1_1: return "SPIR-V 1.3 (under Vulkan 1.1 semantics)"; + case SPV_ENV_WEBGPU_0: + return "SPIR-V 1.3 (under WIP WebGPU semantics)"; } assert(0 && "Unhandled SPIR-V target environment"); return ""; @@ -87,6 +91,7 @@ uint32_t spvVersionForTargetEnv(spv_target_env env) { return SPV_SPIRV_VERSION_WORD(1, 2); case SPV_ENV_UNIVERSAL_1_3: case SPV_ENV_VULKAN_1_1: + case SPV_ENV_WEBGPU_0: return SPV_SPIRV_VERSION_WORD(1, 3); } assert(0 && "Unhandled SPIR-V target environment"); @@ -154,6 +159,9 @@ bool spvParseTargetEnv(const char* s, spv_target_env* env) { } else if (match("opengl4.5")) { if (env) *env = SPV_ENV_OPENGL_4_5; return true; + } else if (match("webgpu0")) { + if (env) *env = SPV_ENV_WEBGPU_0; + return true; } else { if (env) *env = SPV_ENV_UNIVERSAL_1_0; return false; @@ -179,6 +187,7 @@ bool spvIsVulkanEnv(spv_target_env env) { case SPV_ENV_OPENCL_2_2: case SPV_ENV_OPENCL_EMBEDDED_2_2: case SPV_ENV_UNIVERSAL_1_3: + case SPV_ENV_WEBGPU_0: return false; case SPV_ENV_VULKAN_1_0: case SPV_ENV_VULKAN_1_1: diff --git a/3rdparty/spirv-tools/source/spirv_target_env.h b/3rdparty/spirv-tools/source/spirv_target_env.h index 315dbbe91..7dc7be1d8 100644 --- a/3rdparty/spirv-tools/source/spirv_target_env.h +++ b/3rdparty/spirv-tools/source/spirv_target_env.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_SPIRV_TARGET_ENV_H_ -#define LIBSPIRV_SPIRV_TARGET_ENV_H_ +#ifndef SOURCE_SPIRV_TARGET_ENV_H_ +#define SOURCE_SPIRV_TARGET_ENV_H_ #include "spirv-tools/libspirv.h" @@ -27,4 +27,4 @@ bool spvIsVulkanEnv(spv_target_env env); // Returns the version number for the given SPIR-V target environment. uint32_t spvVersionForTargetEnv(spv_target_env env); -#endif // LIBSPIRV_SPIRV_TARGET_ENV_H_ +#endif // SOURCE_SPIRV_TARGET_ENV_H_ diff --git a/3rdparty/spirv-tools/source/spirv_validator_options.cpp b/3rdparty/spirv-tools/source/spirv_validator_options.cpp index fe522da48..0c0625364 100644 --- a/3rdparty/spirv-tools/source/spirv_validator_options.cpp +++ b/3rdparty/spirv-tools/source/spirv_validator_options.cpp @@ -15,7 +15,7 @@ #include #include -#include "spirv_validator_options.h" +#include "source/spirv_validator_options.h" bool spvParseUniversalLimitsOptions(const char* s, spv_validator_limit* type) { auto match = [s](const char* b) { @@ -84,5 +84,15 @@ void spvValidatorOptionsSetRelaxStoreStruct(spv_validator_options options, void spvValidatorOptionsSetRelaxLogicalPointer(spv_validator_options options, bool val) { - options->relax_logcial_pointer = val; + options->relax_logical_pointer = val; +} + +void spvValidatorOptionsSetRelaxBlockLayout(spv_validator_options options, + bool val) { + options->relax_block_layout = val; +} + +void spvValidatorOptionsSetSkipBlockLayout(spv_validator_options options, + bool val) { + options->skip_block_layout = val; } diff --git a/3rdparty/spirv-tools/source/spirv_validator_options.h b/3rdparty/spirv-tools/source/spirv_validator_options.h index d15b63bb0..d264a7e0b 100644 --- a/3rdparty/spirv-tools/source/spirv_validator_options.h +++ b/3rdparty/spirv-tools/source/spirv_validator_options.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_SPIRV_VALIDATOR_OPTIONS_H_ -#define LIBSPIRV_SPIRV_VALIDATOR_OPTIONS_H_ +#ifndef SOURCE_SPIRV_VALIDATOR_OPTIONS_H_ +#define SOURCE_SPIRV_VALIDATOR_OPTIONS_H_ #include "spirv-tools/libspirv.h" @@ -40,11 +40,15 @@ struct spv_validator_options_t { spv_validator_options_t() : universal_limits_(), relax_struct_store(false), - relax_logcial_pointer(false) {} + relax_logical_pointer(false), + relax_block_layout(false), + skip_block_layout(false) {} validator_universal_limits_t universal_limits_; bool relax_struct_store; - bool relax_logcial_pointer; + bool relax_logical_pointer; + bool relax_block_layout; + bool skip_block_layout; }; -#endif // LIBSPIRV_SPIRV_VALIDATOR_OPTIONS_H_ +#endif // SOURCE_SPIRV_VALIDATOR_OPTIONS_H_ diff --git a/3rdparty/spirv-tools/source/table.cpp b/3rdparty/spirv-tools/source/table.cpp index 1a40e27c6..b10d776da 100644 --- a/3rdparty/spirv-tools/source/table.cpp +++ b/3rdparty/spirv-tools/source/table.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "table.h" +#include "source/table.h" #include @@ -37,6 +37,7 @@ spv_context spvContextCreate(spv_target_env env) { case SPV_ENV_UNIVERSAL_1_2: case SPV_ENV_UNIVERSAL_1_3: case SPV_ENV_VULKAN_1_1: + case SPV_ENV_WEBGPU_0: break; default: return nullptr; @@ -56,7 +57,7 @@ spv_context spvContextCreate(spv_target_env env) { void spvContextDestroy(spv_context context) { delete context; } -void libspirv::SetContextMessageConsumer(spv_context context, +void spvtools::SetContextMessageConsumer(spv_context context, spvtools::MessageConsumer consumer) { context->consumer = std::move(consumer); } diff --git a/3rdparty/spirv-tools/source/table.h b/3rdparty/spirv-tools/source/table.h index d4e58348c..64d73dbb9 100644 --- a/3rdparty/spirv-tools/source/table.h +++ b/3rdparty/spirv-tools/source/table.h @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TABLE_H_ -#define LIBSPIRV_TABLE_H_ +#ifndef SOURCE_TABLE_H_ +#define SOURCE_TABLE_H_ -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" -#include "extensions.h" -#include "message.h" +#include "source/extensions.h" #include "spirv-tools/libspirv.hpp" typedef struct spv_opcode_desc_t { @@ -38,7 +37,7 @@ typedef struct spv_opcode_desc_t { // assembler, binary parser, and disassembler ignore this rule, so you can // freely process invalid modules. const uint32_t numExtensions; - const libspirv::Extension* extensions; + const spvtools::Extension* extensions; // Minimal core SPIR-V version required for this feature, if without // extensions. ~0u means reserved for future use. ~0u and non-empty extension // lists means only available in extensions. @@ -55,7 +54,7 @@ typedef struct spv_operand_desc_t { // assembler, binary parser, and disassembler ignore this rule, so you can // freely process invalid modules. const uint32_t numExtensions; - const libspirv::Extension* extensions; + const spvtools::Extension* extensions; const spv_operand_type_t operandTypes[16]; // TODO: Smaller/larger? // Minimal core SPIR-V version required for this feature, if without // extensions. ~0u means reserved for future use. ~0u and non-empty extension @@ -114,12 +113,12 @@ struct spv_context_t { spvtools::MessageConsumer consumer; }; -namespace libspirv { +namespace spvtools { + // Sets the message consumer to |consumer| in the given |context|. The original // message consumer will be overwritten. -void SetContextMessageConsumer(spv_context context, - spvtools::MessageConsumer consumer); -} // namespace libspirv +void SetContextMessageConsumer(spv_context context, MessageConsumer consumer); +} // namespace spvtools // Populates *table with entries for env. spv_result_t spvOpcodeTableGet(spv_opcode_table* table, spv_target_env env); @@ -130,4 +129,4 @@ spv_result_t spvOperandTableGet(spv_operand_table* table, spv_target_env env); // Populates *table with entries for env. spv_result_t spvExtInstTableGet(spv_ext_inst_table* table, spv_target_env env); -#endif // LIBSPIRV_TABLE_H_ +#endif // SOURCE_TABLE_H_ diff --git a/3rdparty/spirv-tools/source/text.cpp b/3rdparty/spirv-tools/source/text.cpp index ac4f8a1f0..adaf79652 100644 --- a/3rdparty/spirv-tools/source/text.cpp +++ b/3rdparty/spirv-tools/source/text.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "text.h" +#include "source/text.h" #include #include @@ -21,26 +21,27 @@ #include #include #include +#include #include #include #include +#include #include -#include "assembly_grammar.h" -#include "binary.h" -#include "diagnostic.h" -#include "ext_inst.h" -#include "instruction.h" -#include "message.h" -#include "opcode.h" -#include "operand.h" +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/ext_inst.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" +#include "source/table.h" +#include "source/text_handler.h" +#include "source/util/bitutils.h" +#include "source/util/parse_number.h" #include "spirv-tools/libspirv.h" -#include "spirv_constant.h" -#include "spirv_target_env.h" -#include "table.h" -#include "text_handler.h" -#include "util/bitutils.h" -#include "util/parse_number.h" bool spvIsValidIDCharacter(const char value) { return value == '_' || 0 != ::isalnum(value); @@ -158,11 +159,11 @@ namespace { /// successful, adds the parsed value to pInst, advances the context past it, /// and returns SPV_SUCCESS. Otherwise, leaves pInst alone, emits diagnostics, /// and returns SPV_ERROR_INVALID_TEXT. -spv_result_t encodeImmediate(libspirv::AssemblyContext* context, +spv_result_t encodeImmediate(spvtools::AssemblyContext* context, const char* text, spv_instruction_t* pInst) { assert(*text == '!'); uint32_t parse_result; - if (!spvutils::ParseNumber(text + 1, &parse_result)) { + if (!spvtools::utils::ParseNumber(text + 1, &parse_result)) { return context->diagnostic(SPV_ERROR_INVALID_TEXT) << "Invalid immediate integer: !" << text + 1; } @@ -183,8 +184,8 @@ spv_result_t encodeImmediate(libspirv::AssemblyContext* context, /// @param[in,out] pExpectedOperands the operand types expected /// /// @return result code -spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar, - libspirv::AssemblyContext* context, +spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar, + spvtools::AssemblyContext* context, const spv_operand_type_t type, const char* textValue, spv_instruction_t* pInst, @@ -279,8 +280,8 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar, case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { // The current operand is an *unsigned* 32-bit integer. // That's just how the grammar works. - libspirv::IdType expected_type = { - 32, false, libspirv::IdTypeClass::kScalarIntegerType}; + spvtools::IdType expected_type = { + 32, false, spvtools::IdTypeClass::kScalarIntegerType}; if (auto error = context->binaryEncodeNumericLiteral( textValue, error_code_for_literals, expected_type, pInst)) { return error; @@ -291,7 +292,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar, // This is a context-independent literal number which can be a 32-bit // number of floating point value. if (auto error = context->binaryEncodeNumericLiteral( - textValue, error_code_for_literals, libspirv::kUnknownType, + textValue, error_code_for_literals, spvtools::kUnknownType, pInst)) { return error; } @@ -299,7 +300,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar, case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { - libspirv::IdType expected_type = libspirv::kUnknownType; + spvtools::IdType expected_type = spvtools::kUnknownType; // The encoding for OpConstant, OpSpecConstant and OpSwitch all // depend on either their own result-id or the result-id of // one of their parameters. @@ -309,8 +310,8 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar, // instruction. expected_type = context->getTypeOfTypeGeneratingValue(pInst->resultTypeId); - if (!libspirv::isScalarFloating(expected_type) && - !libspirv::isScalarIntegral(expected_type)) { + if (!spvtools::isScalarFloating(expected_type) && + !spvtools::isScalarIntegral(expected_type)) { spv_opcode_desc d; const char* opcode_name = "opcode"; if (SPV_SUCCESS == grammar.lookupOpcode(pInst->opcode, &d)) { @@ -323,7 +324,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar, } else if (pInst->opcode == SpvOpSwitch) { // The type of the literal is the same as the type of the selector. expected_type = context->getTypeOfValueInstruction(pInst->words[1]); - if (!libspirv::isScalarIntegral(expected_type)) { + if (!spvtools::isScalarIntegral(expected_type)) { return context->diagnostic() << "The selector operand for OpSwitch must be the result" " of an instruction that generates an integer scalar"; @@ -438,8 +439,8 @@ namespace { /// instruction and returns SPV_SUCCESS. Otherwise, returns an error code and /// leaves position pointing to the error in text. spv_result_t encodeInstructionStartingWithImmediate( - const libspirv::AssemblyGrammar& grammar, - libspirv::AssemblyContext* context, spv_instruction_t* pInst) { + const spvtools::AssemblyGrammar& grammar, + spvtools::AssemblyContext* context, spv_instruction_t* pInst) { std::string firstWord; spv_position_t nextPosition = {}; auto error = context->getWord(&firstWord, &nextPosition); @@ -482,8 +483,8 @@ spv_result_t encodeInstructionStartingWithImmediate( /// @param[in,out] pPosition in the text stream /// /// @return result code -spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar& grammar, - libspirv::AssemblyContext* context, +spv_result_t spvTextEncodeOpcode(const spvtools::AssemblyGrammar& grammar, + spvtools::AssemblyContext* context, spv_instruction_t* pInst) { // Check for ! first. if ('!' == context->peek()) { @@ -667,11 +668,11 @@ spv_result_t SetHeader(spv_target_env env, const uint32_t bound, // Collects all numeric ids in the module source into |numeric_ids|. // This function is essentially a dry-run of spvTextToBinary. -spv_result_t GetNumericIds(const libspirv::AssemblyGrammar& grammar, +spv_result_t GetNumericIds(const spvtools::AssemblyGrammar& grammar, const spvtools::MessageConsumer& consumer, const spv_text text, std::set* numeric_ids) { - libspirv::AssemblyContext context(text, consumer); + spvtools::AssemblyContext context(text, consumer); if (!text->str) return context.diagnostic() << "Missing assembly text."; @@ -699,7 +700,7 @@ spv_result_t GetNumericIds(const libspirv::AssemblyGrammar& grammar, // Translates a given assembly language module into binary form. // If a diagnostic is generated, it is not yet marked as being // for a text-based input. -spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar, +spv_result_t spvTextToBinaryInternal(const spvtools::AssemblyGrammar& grammar, const spvtools::MessageConsumer& consumer, const spv_text text, const uint32_t options, @@ -715,7 +716,7 @@ spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar, if (result != SPV_SUCCESS) return result; } - libspirv::AssemblyContext context(text, consumer, std::move(ids_to_preserve)); + spvtools::AssemblyContext context(text, consumer, std::move(ids_to_preserve)); if (!text->str) return context.diagnostic() << "Missing assembly text."; @@ -790,11 +791,11 @@ spv_result_t spvTextToBinaryWithOptions(const spv_const_context context, spv_context_t hijack_context = *context; if (pDiagnostic) { *pDiagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); } spv_text_t text = {input_text, input_text_size}; - libspirv::AssemblyGrammar grammar(&hijack_context); + spvtools::AssemblyGrammar grammar(&hijack_context); spv_result_t result = spvTextToBinaryInternal( grammar, hijack_context.consumer, &text, options, pBinary); diff --git a/3rdparty/spirv-tools/source/text.h b/3rdparty/spirv-tools/source/text.h index 19990cc93..fa34ee16b 100644 --- a/3rdparty/spirv-tools/source/text.h +++ b/3rdparty/spirv-tools/source/text.h @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEXT_H_ -#define LIBSPIRV_TEXT_H_ +#ifndef SOURCE_TEXT_H_ +#define SOURCE_TEXT_H_ #include -#include "operand.h" +#include "source/operand.h" +#include "source/spirv_constant.h" #include "spirv-tools/libspirv.h" -#include "spirv_constant.h" typedef enum spv_literal_type_t { SPV_LITERAL_TYPE_INT_32, @@ -50,4 +50,4 @@ typedef struct spv_literal_t { // which are then stripped. spv_result_t spvTextToLiteral(const char* text, spv_literal_t* literal); -#endif // LIBSPIRV_TEXT_H_ +#endif // SOURCE_TEXT_H_ diff --git a/3rdparty/spirv-tools/source/text_handler.cpp b/3rdparty/spirv-tools/source/text_handler.cpp index 1a1b48d04..5f6e8c41f 100644 --- a/3rdparty/spirv-tools/source/text_handler.cpp +++ b/3rdparty/spirv-tools/source/text_handler.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "text_handler.h" +#include "source/text_handler.h" #include #include @@ -20,17 +20,19 @@ #include #include -#include "assembly_grammar.h" -#include "binary.h" -#include "ext_inst.h" -#include "instruction.h" -#include "opcode.h" -#include "text.h" -#include "util/bitutils.h" -#include "util/hex_float.h" -#include "util/parse_number.h" +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/ext_inst.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/text.h" +#include "source/util/bitutils.h" +#include "source/util/hex_float.h" +#include "source/util/parse_number.h" +namespace spvtools { namespace { + // Advances |text| to the start of the next line and writes the new position to // |position|. spv_result_t advanceLine(spv_text text, spv_position position) { @@ -105,9 +107,9 @@ spv_result_t getWord(spv_text text, spv_position position, std::string* word) { return SPV_SUCCESS; } const char ch = text->str[position->index]; - if (ch == '\\') + if (ch == '\\') { escaping = !escaping; - else { + } else { switch (ch) { case '"': if (!escaping) quoting = !quoting; @@ -144,9 +146,7 @@ bool startsWithOp(spv_text text, spv_position position) { return ('O' == ch0 && 'p' == ch1 && ('A' <= ch2 && ch2 <= 'Z')); } -} // anonymous namespace - -namespace libspirv { +} // namespace const IdType kUnknownType = {0, false, IdTypeClass::kBottom}; @@ -157,7 +157,7 @@ const IdType kUnknownType = {0, false, IdTypeClass::kBottom}; uint32_t AssemblyContext::spvNamedIdAssignOrGet(const char* textValue) { if (!ids_to_preserve_.empty()) { uint32_t id = 0; - if (spvutils::ParseNumber(textValue, &id)) { + if (spvtools::utils::ParseNumber(textValue, &id)) { if (ids_to_preserve_.find(id) != ids_to_preserve_.end()) { bound_ = std::max(bound_, id + 1); return id; @@ -185,35 +185,35 @@ uint32_t AssemblyContext::spvNamedIdAssignOrGet(const char* textValue) { uint32_t AssemblyContext::getBound() const { return bound_; } spv_result_t AssemblyContext::advance() { - return ::advance(text_, ¤t_position_); + return spvtools::advance(text_, ¤t_position_); } spv_result_t AssemblyContext::getWord(std::string* word, spv_position next_position) { *next_position = current_position_; - return ::getWord(text_, next_position, word); + return spvtools::getWord(text_, next_position, word); } bool AssemblyContext::startsWithOp() { - return ::startsWithOp(text_, ¤t_position_); + return spvtools::startsWithOp(text_, ¤t_position_); } bool AssemblyContext::isStartOfNewInst() { spv_position_t pos = current_position_; - if (::advance(text_, &pos)) return false; - if (::startsWithOp(text_, &pos)) return true; + if (spvtools::advance(text_, &pos)) return false; + if (spvtools::startsWithOp(text_, &pos)) return true; std::string word; pos = current_position_; - if (::getWord(text_, &pos, &word)) return false; + if (spvtools::getWord(text_, &pos, &word)) return false; if ('%' != word.front()) return false; - if (::advance(text_, &pos)) return false; - if (::getWord(text_, &pos, &word)) return false; + if (spvtools::advance(text_, &pos)) return false; + if (spvtools::getWord(text_, &pos, &word)) return false; if ("=" != word) return false; - if (::advance(text_, &pos)) return false; - if (::startsWithOp(text_, &pos)) return true; + if (spvtools::advance(text_, &pos)) return false; + if (spvtools::startsWithOp(text_, &pos)) return true; return false; } @@ -239,9 +239,9 @@ spv_result_t AssemblyContext::binaryEncodeU32(const uint32_t value, spv_result_t AssemblyContext::binaryEncodeNumericLiteral( const char* val, spv_result_t error_code, const IdType& type, spv_instruction_t* pInst) { - using spvutils::EncodeNumberStatus; + using spvtools::utils::EncodeNumberStatus; // Populate the NumberType from the IdType for parsing. - spvutils::NumberType number_type; + spvtools::utils::NumberType number_type; switch (type.type_class) { case IdTypeClass::kOtherType: return diagnostic(SPV_ERROR_INTERNAL) @@ -389,9 +389,9 @@ std::set AssemblyContext::GetNumericIds() const { std::set ids; for (const auto& kv : named_ids_) { uint32_t id; - if (spvutils::ParseNumber(kv.first.c_str(), &id)) ids.insert(id); + if (spvtools::utils::ParseNumber(kv.first.c_str(), &id)) ids.insert(id); } return ids; } -} // namespace libspirv +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/text_handler.h b/3rdparty/spirv-tools/source/text_handler.h index e49b51bbd..19972e951 100644 --- a/3rdparty/spirv-tools/source/text_handler.h +++ b/3rdparty/spirv-tools/source/text_handler.h @@ -12,21 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEXT_HANDLER_H_ -#define LIBSPIRV_TEXT_HANDLER_H_ +#ifndef SOURCE_TEXT_HANDLER_H_ +#define SOURCE_TEXT_HANDLER_H_ #include +#include #include +#include #include #include +#include -#include "diagnostic.h" -#include "instruction.h" -#include "message.h" +#include "source/diagnostic.h" +#include "source/instruction.h" +#include "source/text.h" #include "spirv-tools/libspirv.h" -#include "text.h" -namespace libspirv { +namespace spvtools { + // Structures // This is a lattice for tracking types. @@ -117,7 +120,7 @@ class ClampToZeroIfUnsignedType< // Encapsulates the data used during the assembly of a SPIR-V module. class AssemblyContext { public: - AssemblyContext(spv_text text, const spvtools::MessageConsumer& consumer, + AssemblyContext(spv_text text, const MessageConsumer& consumer, std::set&& ids_to_preserve = std::set()) : current_position_({}), consumer_(consumer), @@ -152,7 +155,7 @@ class AssemblyContext { // stream, and for the given error code. Any data written to this object will // show up in pDiagnsotic on destruction. DiagnosticStream diagnostic(spv_result_t error) { - return DiagnosticStream(current_position_, consumer_, error); + return DiagnosticStream(current_position_, consumer_, "", error); } // Returns a diagnostic object with the default assembly error code. @@ -249,11 +252,13 @@ class AssemblyContext { // Maps an extended instruction import Id to the extended instruction type. std::unordered_map import_id_to_ext_inst_type_; spv_position_t current_position_; - spvtools::MessageConsumer consumer_; + MessageConsumer consumer_; spv_text text_; uint32_t bound_; uint32_t next_id_; std::set ids_to_preserve_; }; -} // namespace libspirv -#endif // _LIBSPIRV_TEXT_HANDLER_H_ + +} // namespace spvtools + +#endif // SOURCE_TEXT_HANDLER_H_ diff --git a/3rdparty/spirv-tools/source/util/bit_vector.cpp b/3rdparty/spirv-tools/source/util/bit_vector.cpp new file mode 100644 index 000000000..47e275bf4 --- /dev/null +++ b/3rdparty/spirv-tools/source/util/bit_vector.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/util/bit_vector.h" + +#include +#include + +namespace spvtools { +namespace utils { + +void BitVector::ReportDensity(std::ostream& out) { + uint32_t count = 0; + + for (BitContainer e : bits_) { + while (e != 0) { + if ((e & 1) != 0) { + ++count; + } + e = e >> 1; + } + } + + out << "count=" << count + << ", total size (bytes)=" << bits_.size() * sizeof(BitContainer) + << ", bytes per element=" + << (double)(bits_.size() * sizeof(BitContainer)) / (double)(count); +} + +bool BitVector::Or(const BitVector& other) { + auto this_it = this->bits_.begin(); + auto other_it = other.bits_.begin(); + bool modified = false; + + while (this_it != this->bits_.end() && other_it != other.bits_.end()) { + auto temp = *this_it | *other_it; + if (temp != *this_it) { + modified = true; + *this_it = temp; + } + ++this_it; + ++other_it; + } + + if (other_it != other.bits_.end()) { + modified = true; + this->bits_.insert(this->bits_.end(), other_it, other.bits_.end()); + } + + return modified; +} + +std::ostream& operator<<(std::ostream& out, const BitVector& bv) { + out << "{"; + for (uint32_t i = 0; i < bv.bits_.size(); ++i) { + BitVector::BitContainer b = bv.bits_[i]; + uint32_t j = 0; + while (b != 0) { + if (b & 1) { + out << ' ' << i * BitVector::kBitContainerSize + j; + } + ++j; + b = b >> 1; + } + } + out << "}"; + return out; +} + +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/util/bit_vector.h b/3rdparty/spirv-tools/source/util/bit_vector.h new file mode 100644 index 000000000..3e189cb10 --- /dev/null +++ b/3rdparty/spirv-tools/source/util/bit_vector.h @@ -0,0 +1,119 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_BIT_VECTOR_H_ +#define SOURCE_UTIL_BIT_VECTOR_H_ + +#include +#include +#include + +namespace spvtools { +namespace utils { + +// Implements a bit vector class. +// +// All bits default to zero, and the upper bound is 2^32-1. +class BitVector { + private: + using BitContainer = uint64_t; + enum { kBitContainerSize = 64 }; + enum { kInitialNumBits = 1024 }; + + public: + // Creates a bit vector contianing 0s. + BitVector(uint32_t reserved_size = kInitialNumBits) + : bits_((reserved_size - 1) / kBitContainerSize + 1, 0) {} + + // Sets the |i|th bit to 1. Returns the |i|th bit before it was set. + bool Set(uint32_t i) { + uint32_t element_index = i / kBitContainerSize; + uint32_t bit_in_element = i % kBitContainerSize; + + if (element_index >= bits_.size()) { + bits_.resize(element_index + 1, 0); + } + + BitContainer original = bits_[element_index]; + BitContainer ith_bit = static_cast(1) << bit_in_element; + + if ((original & ith_bit) != 0) { + return true; + } else { + bits_[element_index] = original | ith_bit; + return false; + } + } + + // Sets the |i|th bit to 0. Return the |i|th bit before it was cleared. + bool Clear(uint32_t i) { + uint32_t element_index = i / kBitContainerSize; + uint32_t bit_in_element = i % kBitContainerSize; + + if (element_index >= bits_.size()) { + return false; + } + + BitContainer original = bits_[element_index]; + BitContainer ith_bit = static_cast(1) << bit_in_element; + + if ((original & ith_bit) == 0) { + return false; + } else { + bits_[element_index] = original & (~ith_bit); + return true; + } + } + + // Returns the |i|th bit. + bool Get(uint32_t i) const { + uint32_t element_index = i / kBitContainerSize; + uint32_t bit_in_element = i % kBitContainerSize; + + if (element_index >= bits_.size()) { + return false; + } + + return (bits_[element_index] & + (static_cast(1) << bit_in_element)) != 0; + } + + // Returns true if every bit is 0. + bool Empty() const { + for (BitContainer b : bits_) { + if (b != 0) { + return false; + } + } + return true; + } + + // Print a report on the densicy of the bit vector, number of 1 bits, number + // of bytes, and average bytes for 1 bit, to |out|. + void ReportDensity(std::ostream& out); + + friend std::ostream& operator<<(std::ostream&, const BitVector&); + + // Performs a bitwise-or operation on |this| and |that|, storing the result in + // |this|. Return true if |this| changed. + bool Or(const BitVector& that); + + private: + std::vector bits_; +}; + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_BIT_VECTOR_H_ diff --git a/3rdparty/spirv-tools/source/util/bitutils.h b/3rdparty/spirv-tools/source/util/bitutils.h index 9b53d3b2d..17d61df90 100644 --- a/3rdparty/spirv-tools/source/util/bitutils.h +++ b/3rdparty/spirv-tools/source/util/bitutils.h @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_BITUTILS_H_ -#define LIBSPIRV_UTIL_BITUTILS_H_ +#ifndef SOURCE_UTIL_BITUTILS_H_ +#define SOURCE_UTIL_BITUTILS_H_ #include #include -namespace spvutils { +namespace spvtools { +namespace utils { // Performs a bitwise copy of source to the destination type Dest. template @@ -89,6 +90,7 @@ size_t CountSetBits(T word) { return count; } -} // namespace spvutils +} // namespace utils +} // namespace spvtools -#endif // LIBSPIRV_UTIL_BITUTILS_H_ +#endif // SOURCE_UTIL_BITUTILS_H_ diff --git a/3rdparty/spirv-tools/source/util/hex_float.h b/3rdparty/spirv-tools/source/util/hex_float.h index de99cc356..b7baf093b 100644 --- a/3rdparty/spirv-tools/source/util/hex_float.h +++ b/3rdparty/spirv-tools/source/util/hex_float.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_HEX_FLOAT_H_ -#define LIBSPIRV_UTIL_HEX_FLOAT_H_ +#ifndef SOURCE_UTIL_HEX_FLOAT_H_ +#define SOURCE_UTIL_HEX_FLOAT_H_ #include #include @@ -24,7 +24,7 @@ #include #include -#include "bitutils.h" +#include "source/util/bitutils.h" #ifndef __GNUC__ #define GCC_VERSION 0 @@ -33,7 +33,8 @@ (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) #endif -namespace spvutils { +namespace spvtools { +namespace utils { class Float16 { public: @@ -317,8 +318,7 @@ class HexFloat { // The representation of the fraction, not the actual bits. This // includes the leading bit that is usually implicit. static const uint_type fraction_represent_mask = - spvutils::SetBits::get; + SetBits::get; // The topmost bit in the nibble-aligned fraction. static const uint_type fraction_top_bit = @@ -332,14 +332,14 @@ class HexFloat { // The mask for the encoded fraction. It does not include the // implicit bit. static const uint_type fraction_encode_mask = - spvutils::SetBits::get; + SetBits::get; // The bit that is used as a sign. static const uint_type sign_mask = uint_type(1) << top_bit_left_shift; // The bits that represent the exponent. static const uint_type exponent_mask = - spvutils::SetBits::get; + SetBits::get; // How far left the exponent is shifted. static const uint32_t exponent_left_shift = num_fraction_bits; @@ -568,7 +568,7 @@ class HexFloat { static const uint_type throwaway_mask_bits = num_throwaway_bits > 0 ? num_throwaway_bits : 0; static const uint_type throwaway_mask = - spvutils::SetBits::get; + SetBits::get; *carry_bit = false; other_uint_type out_val = 0; @@ -1143,6 +1143,8 @@ inline std::ostream& operator<<(std::ostream& os, os << HexFloat>(value); return os; } -} // namespace spvutils -#endif // LIBSPIRV_UTIL_HEX_FLOAT_H_ +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_HEX_FLOAT_H_ diff --git a/3rdparty/spirv-tools/source/util/ilist.h b/3rdparty/spirv-tools/source/util/ilist.h index c4287d248..9837b09b3 100644 --- a/3rdparty/spirv-tools/source/util/ilist.h +++ b/3rdparty/spirv-tools/source/util/ilist.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_ILIST_H_ -#define LIBSPIRV_OPT_ILIST_H_ +#ifndef SOURCE_UTIL_ILIST_H_ +#define SOURCE_UTIL_ILIST_H_ #include #include #include #include -#include "ilist_node.h" +#include "source/util/ilist_node.h" namespace spvtools { namespace utils { @@ -362,4 +362,4 @@ void IntrusiveList::Check(NodeType* start) { } // namespace utils } // namespace spvtools -#endif // LIBSPIRV_OPT_ILIST_H_ +#endif // SOURCE_UTIL_ILIST_H_ diff --git a/3rdparty/spirv-tools/source/util/ilist_node.h b/3rdparty/spirv-tools/source/util/ilist_node.h index 76ea302d7..0579534b8 100644 --- a/3rdparty/spirv-tools/source/util/ilist_node.h +++ b/3rdparty/spirv-tools/source/util/ilist_node.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_ILIST_NODE_H_ -#define LIBSPIRV_OPT_ILIST_NODE_H_ +#ifndef SOURCE_UTIL_ILIST_NODE_H_ +#define SOURCE_UTIL_ILIST_NODE_H_ #include @@ -67,7 +67,7 @@ class IntrusiveNodeBase { // from that list. // // It is assumed that the given node is of type NodeType. It is an error if - // |pos| is not already in a list. + // |pos| is not already in a list, or if |pos| is equal to |this|. inline void InsertAfter(NodeType* pos); // Removes the given node from the list. It is assumed that the node is @@ -185,6 +185,8 @@ template inline void IntrusiveNodeBase::InsertAfter(NodeType* pos) { assert(!this->is_sentinel_ && "Sentinel nodes cannot be moved around."); assert(pos->IsInAList() && "Pos should already be in a list."); + assert(this != pos && "Can't insert a node after itself."); + if (this->IsInAList()) { this->RemoveFromList(); } @@ -260,4 +262,4 @@ bool IntrusiveNodeBase::IsEmptyList() { } // namespace utils } // namespace spvtools -#endif // LIBSPIRV_OPT_ILIST_NODE_H_ +#endif // SOURCE_UTIL_ILIST_NODE_H_ diff --git a/3rdparty/spirv-tools/source/opt/make_unique.h b/3rdparty/spirv-tools/source/util/make_unique.h similarity index 88% rename from 3rdparty/spirv-tools/source/opt/make_unique.h rename to 3rdparty/spirv-tools/source/util/make_unique.h index f6b8aae28..ad7976c34 100644 --- a/3rdparty/spirv-tools/source/opt/make_unique.h +++ b/3rdparty/spirv-tools/source/util/make_unique.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_MAKE_UNIQUE_H_ -#define LIBSPIRV_OPT_MAKE_UNIQUE_H_ +#ifndef SOURCE_UTIL_MAKE_UNIQUE_H_ +#define SOURCE_UTIL_MAKE_UNIQUE_H_ #include #include @@ -27,4 +27,4 @@ std::unique_ptr MakeUnique(Args&&... args) { } // namespace spvtools -#endif // LIBSPIRV_OPT_MAKE_UNIQUE_H_ +#endif // SOURCE_UTIL_MAKE_UNIQUE_H_ diff --git a/3rdparty/spirv-tools/source/util/move_to_front.h b/3rdparty/spirv-tools/source/util/move_to_front.h deleted file mode 100644 index de405dde4..000000000 --- a/3rdparty/spirv-tools/source/util/move_to_front.h +++ /dev/null @@ -1,825 +0,0 @@ -// Copyright (c) 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef LIBSPIRV_UTIL_MOVE_TO_FRONT_H_ -#define LIBSPIRV_UTIL_MOVE_TO_FRONT_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace spvutils { - -// Log(n) move-to-front implementation. Implements the following functions: -// Insert - pushes value to the front of the mtf sequence -// (only unique values allowed). -// Remove - remove value from the sequence. -// ValueFromRank - access value by its 1-indexed rank in the sequence. -// RankFromValue - get the rank of the given value in the sequence. -// Accessing a value with ValueFromRank or RankFromValue moves the value to the -// front of the sequence (rank of 1). -// -// The implementation is based on an AVL-based order statistic tree. The tree -// is ordered by timestamps issued when values are inserted or accessed (recent -// values go to the left side of the tree, old values are gradually rotated to -// the right side). -// -// Terminology -// rank: 1-indexed rank showing how recently the value was inserted or accessed. -// node: handle used internally to access node data. -// size: size of the subtree of a node (including the node). -// height: distance from a node to the farthest leaf. -template -class MoveToFront { - public: - explicit MoveToFront(size_t reserve_capacity = 4) { - nodes_.reserve(reserve_capacity); - - // Create NIL node. - nodes_.emplace_back(Node()); - } - - virtual ~MoveToFront() {} - - // Inserts value in the move-to-front sequence. Does nothing if the value is - // already in the sequence. Returns true if insertion was successful. - // The inserted value is placed at the front of the sequence (rank 1). - bool Insert(const Val& value); - - // Removes value from move-to-front sequence. Returns false iff the value - // was not found. - bool Remove(const Val& value); - - // Computes 1-indexed rank of value in the move-to-front sequence and moves - // the value to the front. Example: - // Before the call: 4 8 2 1 7 - // RankFromValue(8) returns 2 - // After the call: 8 4 2 1 7 - // Returns true iff the value was found in the sequence. - bool RankFromValue(const Val& value, uint32_t* rank); - - // Returns value corresponding to a 1-indexed rank in the move-to-front - // sequence and moves the value to the front. Example: - // Before the call: 4 8 2 1 7 - // ValueFromRank(2) returns 8 - // After the call: 8 4 2 1 7 - // Returns true iff the rank is within bounds [1, GetSize()]. - bool ValueFromRank(uint32_t rank, Val* value); - - // Moves the value to the front of the sequence. - // Returns false iff value is not in the sequence. - bool Promote(const Val& value); - - // Returns true iff the move-to-front sequence contains the value. - bool HasValue(const Val& value) const; - - // Returns the number of elements in the move-to-front sequence. - uint32_t GetSize() const { return SizeOf(root_); } - - protected: - // Internal tree data structure uses handles instead of pointers. Leaves and - // root parent reference a singleton under handle 0. Although dereferencing - // a null pointer is not possible, inappropriate access to handle 0 would - // cause an assertion. Handles are not garbage collected if value was - // deprecated - // with DeprecateValue(). But handles are recycled when a node is - // repositioned. - - // Internal tree data structure node. - struct Node { - // Timestamp from a logical clock which updates every time the element is - // accessed through ValueFromRank or RankFromValue. - uint32_t timestamp = 0; - // The size of the node's subtree, including the node. - // SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1. - uint32_t size = 0; - // Handles to connected nodes. - uint32_t left = 0; - uint32_t right = 0; - uint32_t parent = 0; - // Distance to the farthest leaf. - // Leaves have height 0, real nodes at least 1. - uint32_t height = 0; - // Stored value. - Val value = Val(); - }; - - // Creates node and sets correct values. Non-NIL nodes should be created only - // through this function. If the node with this value has been created - // previously - // and since orphaned, reuses the old node instead of creating a new one. - uint32_t CreateNode(uint32_t timestamp, const Val& value) { - uint32_t handle = static_cast(nodes_.size()); - const auto result = value_to_node_.emplace(value, handle); - if (result.second) { - // Create new node. - nodes_.emplace_back(Node()); - Node& node = nodes_.back(); - node.timestamp = timestamp; - node.value = value; - node.size = 1; - // Non-NIL nodes start with height 1 because their NIL children are - // leaves. - node.height = 1; - } else { - // Reuse old node. - handle = result.first->second; - assert(!IsInTree(handle)); - assert(ValueOf(handle) == value); - assert(SizeOf(handle) == 1); - assert(HeightOf(handle) == 1); - MutableTimestampOf(handle) = timestamp; - } - - return handle; - } - - // Node accessor methods. Naming is designed to be similar to natural - // language as these functions tend to be used in sequences, for example: - // ParentOf(LeftestDescendentOf(RightOf(node))) - - // Returns value of the node referenced by |handle|. - Val ValueOf(uint32_t node) const { return nodes_.at(node).value; } - - // Returns left child of |node|. - uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; } - - // Returns right child of |node|. - uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; } - - // Returns parent of |node|. - uint32_t ParentOf(uint32_t node) const { return nodes_.at(node).parent; } - - // Returns timestamp of |node|. - uint32_t TimestampOf(uint32_t node) const { - assert(node); - return nodes_.at(node).timestamp; - } - - // Returns size of |node|. - uint32_t SizeOf(uint32_t node) const { return nodes_.at(node).size; } - - // Returns height of |node|. - uint32_t HeightOf(uint32_t node) const { return nodes_.at(node).height; } - - // Returns mutable reference to value of |node|. - Val& MutableValueOf(uint32_t node) { - assert(node); - return nodes_.at(node).value; - } - - // Returns mutable reference to handle of left child of |node|. - uint32_t& MutableLeftOf(uint32_t node) { - assert(node); - return nodes_.at(node).left; - } - - // Returns mutable reference to handle of right child of |node|. - uint32_t& MutableRightOf(uint32_t node) { - assert(node); - return nodes_.at(node).right; - } - - // Returns mutable reference to handle of parent of |node|. - uint32_t& MutableParentOf(uint32_t node) { - assert(node); - return nodes_.at(node).parent; - } - - // Returns mutable reference to timestamp of |node|. - uint32_t& MutableTimestampOf(uint32_t node) { - assert(node); - return nodes_.at(node).timestamp; - } - - // Returns mutable reference to size of |node|. - uint32_t& MutableSizeOf(uint32_t node) { - assert(node); - return nodes_.at(node).size; - } - - // Returns mutable reference to height of |node|. - uint32_t& MutableHeightOf(uint32_t node) { - assert(node); - return nodes_.at(node).height; - } - - // Returns true iff |node| is left child of its parent. - bool IsLeftChild(uint32_t node) const { - assert(node); - return LeftOf(ParentOf(node)) == node; - } - - // Returns true iff |node| is right child of its parent. - bool IsRightChild(uint32_t node) const { - assert(node); - return RightOf(ParentOf(node)) == node; - } - - // Returns true iff |node| has no relatives. - bool IsOrphan(uint32_t node) const { - assert(node); - return !ParentOf(node) && !LeftOf(node) && !RightOf(node); - } - - // Returns true iff |node| is in the tree. - bool IsInTree(uint32_t node) const { - assert(node); - return node == root_ || !IsOrphan(node); - } - - // Returns the height difference between right and left subtrees. - int BalanceOf(uint32_t node) const { - return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node))); - } - - // Updates size and height of the node, assuming that the children have - // correct values. - void UpdateNode(uint32_t node); - - // Returns the most LeftOf(LeftOf(... descendent which is not leaf. - uint32_t LeftestDescendantOf(uint32_t node) const { - uint32_t parent = 0; - while (node) { - parent = node; - node = LeftOf(node); - } - return parent; - } - - // Returns the most RightOf(RightOf(... descendent which is not leaf. - uint32_t RightestDescendantOf(uint32_t node) const { - uint32_t parent = 0; - while (node) { - parent = node; - node = RightOf(node); - } - return parent; - } - - // Inserts node in the tree. The node must be an orphan. - void InsertNode(uint32_t node); - - // Removes node from the tree. May change value_to_node_ if removal uses a - // scapegoat. Returns the removed (orphaned) handle for recycling. The - // returned handle may not be equal to |node| if scapegoat was used. - uint32_t RemoveNode(uint32_t node); - - // Rotates |node| left, reassigns all connections and returns the node - // which takes place of the |node|. - uint32_t RotateLeft(const uint32_t node); - - // Rotates |node| right, reassigns all connections and returns the node - // which takes place of the |node|. - uint32_t RotateRight(const uint32_t node); - - // Root node handle. The tree is empty if root_ is 0. - uint32_t root_ = 0; - - // Incremented counters for next timestamp and value. - uint32_t next_timestamp_ = 1; - - // Holds all tree nodes. Indices of this vector are node handles. - std::vector nodes_; - - // Maps ids to node handles. - std::unordered_map value_to_node_; - - // Cache for the last accessed value in the sequence. - Val last_accessed_value_ = Val(); - bool last_accessed_value_valid_ = false; -}; - -template -class MultiMoveToFront { - public: - // Inserts |value| to sequence with handle |mtf|. - // Returns false if |mtf| already has |value|. - bool Insert(uint64_t mtf, const Val& value) { - if (GetMtf(mtf).Insert(value)) { - val_to_mtfs_[value].insert(mtf); - return true; - } - return false; - } - - // Removes |value| from sequence with handle |mtf|. - // Returns false if |mtf| doesn't have |value|. - bool Remove(uint64_t mtf, const Val& value) { - if (GetMtf(mtf).Remove(value)) { - val_to_mtfs_[value].erase(mtf); - return true; - } - assert(val_to_mtfs_[value].count(mtf) == 0); - return false; - } - - // Removes |value| from all sequences which have it. - void RemoveFromAll(const Val& value) { - auto it = val_to_mtfs_.find(value); - if (it == val_to_mtfs_.end()) return; - - auto& mtfs_containing_value = it->second; - for (uint64_t mtf : mtfs_containing_value) { - GetMtf(mtf).Remove(value); - } - - val_to_mtfs_.erase(value); - } - - // Computes rank of |value| in sequence |mtf|. - // Returns false if |mtf| doesn't have |value|. - bool RankFromValue(uint64_t mtf, const Val& value, uint32_t* rank) { - return GetMtf(mtf).RankFromValue(value, rank); - } - - // Finds |value| with |rank| in sequence |mtf|. - // Returns false if |rank| is out of bounds. - bool ValueFromRank(uint64_t mtf, uint32_t rank, Val* value) { - return GetMtf(mtf).ValueFromRank(rank, value); - } - - // Returns size of |mtf| sequence. - uint32_t GetSize(uint64_t mtf) { return GetMtf(mtf).GetSize(); } - - // Promotes |value| in all sequences which have it. - void Promote(const Val& value) { - const auto it = val_to_mtfs_.find(value); - if (it == val_to_mtfs_.end()) return; - - const auto& mtfs_containing_value = it->second; - for (uint64_t mtf : mtfs_containing_value) { - GetMtf(mtf).Promote(value); - } - } - - // Inserts |value| in sequence |mtf| or promotes if it's already there. - void InsertOrPromote(uint64_t mtf, const Val& value) { - if (!Insert(mtf, value)) { - GetMtf(mtf).Promote(value); - } - } - - // Returns if |mtf| sequence has |value|. - bool HasValue(uint64_t mtf, const Val& value) { - return GetMtf(mtf).HasValue(value); - } - - private: - // Returns actual MoveToFront object corresponding to |handle|. - // As multiple operations are often performed consecutively for the same - // sequence, the last returned value is cached. - MoveToFront& GetMtf(uint64_t handle) { - if (!cached_mtf_ || cached_handle_ != handle) { - cached_handle_ = handle; - cached_mtf_ = &mtfs_[handle]; - } - - return *cached_mtf_; - } - - // Container holding MoveToFront objects. Map key is sequence handle. - std::map> mtfs_; - - // Container mapping value to sequences which contain that value. - std::unordered_map> val_to_mtfs_; - - // Cache for the last accessed sequence. - uint64_t cached_handle_ = 0; - MoveToFront* cached_mtf_ = nullptr; -}; - -template -bool MoveToFront::Insert(const Val& value) { - auto it = value_to_node_.find(value); - if (it != value_to_node_.end() && IsInTree(it->second)) return false; - - const uint32_t old_size = GetSize(); - (void)old_size; - - InsertNode(CreateNode(next_timestamp_++, value)); - - last_accessed_value_ = value; - last_accessed_value_valid_ = true; - - assert(value_to_node_.count(value)); - assert(old_size + 1 == GetSize()); - return true; -} - -template -bool MoveToFront::Remove(const Val& value) { - auto it = value_to_node_.find(value); - if (it == value_to_node_.end()) return false; - - if (!IsInTree(it->second)) return false; - - if (last_accessed_value_ == value) last_accessed_value_valid_ = false; - - const uint32_t orphan = RemoveNode(it->second); - (void)orphan; - // The node of |value| is still alive but it's orphaned now. Can still be - // reused later. - assert(!IsInTree(orphan)); - assert(ValueOf(orphan) == value); - return true; -} - -template -bool MoveToFront::RankFromValue(const Val& value, uint32_t* rank) { - if (last_accessed_value_valid_ && last_accessed_value_ == value) { - *rank = 1; - return true; - } - - const uint32_t old_size = GetSize(); - if (old_size == 1) { - if (ValueOf(root_) == value) { - *rank = 1; - return true; - } else { - return false; - } - } - - const auto it = value_to_node_.find(value); - if (it == value_to_node_.end()) { - return false; - } - - uint32_t target = it->second; - - if (!IsInTree(target)) { - return false; - } - - uint32_t node = target; - *rank = 1 + SizeOf(LeftOf(node)); - while (node) { - if (IsRightChild(node)) *rank += 1 + SizeOf(LeftOf(ParentOf(node))); - node = ParentOf(node); - } - - // Don't update timestamp if the node has rank 1. - if (*rank != 1) { - // Update timestamp and reposition the node. - target = RemoveNode(target); - assert(ValueOf(target) == value); - assert(old_size == GetSize() + 1); - MutableTimestampOf(target) = next_timestamp_++; - InsertNode(target); - assert(old_size == GetSize()); - } - - last_accessed_value_ = value; - last_accessed_value_valid_ = true; - return true; -} - -template -bool MoveToFront::HasValue(const Val& value) const { - const auto it = value_to_node_.find(value); - if (it == value_to_node_.end()) { - return false; - } - - return IsInTree(it->second); -} - -template -bool MoveToFront::Promote(const Val& value) { - if (last_accessed_value_valid_ && last_accessed_value_ == value) { - return true; - } - - const uint32_t old_size = GetSize(); - if (old_size == 1) return ValueOf(root_) == value; - - const auto it = value_to_node_.find(value); - if (it == value_to_node_.end()) { - return false; - } - - uint32_t target = it->second; - - if (!IsInTree(target)) { - return false; - } - - // Update timestamp and reposition the node. - target = RemoveNode(target); - assert(ValueOf(target) == value); - assert(old_size == GetSize() + 1); - MutableTimestampOf(target) = next_timestamp_++; - InsertNode(target); - assert(old_size == GetSize()); - - last_accessed_value_ = value; - last_accessed_value_valid_ = true; - return true; -} - -template -bool MoveToFront::ValueFromRank(uint32_t rank, Val* value) { - if (last_accessed_value_valid_ && rank == 1) { - *value = last_accessed_value_; - return true; - } - - const uint32_t old_size = GetSize(); - if (rank <= 0 || rank > old_size) { - return false; - } - - if (old_size == 1) { - *value = ValueOf(root_); - return true; - } - - const bool update_timestamp = (rank != 1); - - uint32_t node = root_; - while (node) { - const uint32_t left_subtree_num_nodes = SizeOf(LeftOf(node)); - if (rank == left_subtree_num_nodes + 1) { - // This is the node we are looking for. - // Don't update timestamp if the node has rank 1. - if (update_timestamp) { - node = RemoveNode(node); - assert(old_size == GetSize() + 1); - MutableTimestampOf(node) = next_timestamp_++; - InsertNode(node); - assert(old_size == GetSize()); - } - *value = ValueOf(node); - last_accessed_value_ = *value; - last_accessed_value_valid_ = true; - return true; - } - - if (rank < left_subtree_num_nodes + 1) { - // Descend into the left subtree. The rank is still valid. - node = LeftOf(node); - } else { - // Descend into the right subtree. We leave behind the left subtree and - // the current node, adjust the |rank| accordingly. - rank -= left_subtree_num_nodes + 1; - node = RightOf(node); - } - } - - assert(0); - return false; -} - -template -void MoveToFront::InsertNode(uint32_t node) { - assert(!IsInTree(node)); - assert(SizeOf(node) == 1); - assert(HeightOf(node) == 1); - assert(TimestampOf(node)); - - if (!root_) { - root_ = node; - return; - } - - uint32_t iter = root_; - uint32_t parent = 0; - - // Will determine if |node| will become the right or left child after - // insertion (but before balancing). - bool right_child = true; - - // Find the node which will become |node|'s parent after insertion - // (but before balancing). - while (iter) { - parent = iter; - assert(TimestampOf(iter) != TimestampOf(node)); - right_child = TimestampOf(iter) > TimestampOf(node); - iter = right_child ? RightOf(iter) : LeftOf(iter); - } - - assert(parent); - - // Connect node and parent. - MutableParentOf(node) = parent; - if (right_child) - MutableRightOf(parent) = node; - else - MutableLeftOf(parent) = node; - - // Insertion is finished. Start the balancing process. - bool needs_rebalancing = true; - parent = ParentOf(node); - - while (parent) { - UpdateNode(parent); - - if (needs_rebalancing) { - const int parent_balance = BalanceOf(parent); - - if (RightOf(parent) == node) { - // Added node to the right subtree. - if (parent_balance > 1) { - // Parent is right heavy, rotate left. - if (BalanceOf(node) < 0) RotateRight(node); - parent = RotateLeft(parent); - } else if (parent_balance == 0 || parent_balance == -1) { - // Parent is balanced or left heavy, no need to balance further. - needs_rebalancing = false; - } - } else { - // Added node to the left subtree. - if (parent_balance < -1) { - // Parent is left heavy, rotate right. - if (BalanceOf(node) > 0) RotateLeft(node); - parent = RotateRight(parent); - } else if (parent_balance == 0 || parent_balance == 1) { - // Parent is balanced or right heavy, no need to balance further. - needs_rebalancing = false; - } - } - } - - assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1)); - - node = parent; - parent = ParentOf(parent); - } -} - -template -uint32_t MoveToFront::RemoveNode(uint32_t node) { - if (LeftOf(node) && RightOf(node)) { - // If |node| has two children, then use another node as scapegoat and swap - // their contents. We pick the scapegoat on the side of the tree which has - // more nodes. - const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node)) - ? RightestDescendantOf(LeftOf(node)) - : LeftestDescendantOf(RightOf(node)); - assert(scapegoat); - std::swap(MutableValueOf(node), MutableValueOf(scapegoat)); - std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat)); - value_to_node_[ValueOf(node)] = node; - value_to_node_[ValueOf(scapegoat)] = scapegoat; - node = scapegoat; - } - - // |node| may have only one child at this point. - assert(!RightOf(node) || !LeftOf(node)); - - uint32_t parent = ParentOf(node); - uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node); - - // Orphan |node| and reconnect parent and child. - if (child) MutableParentOf(child) = parent; - - if (parent) { - if (LeftOf(parent) == node) - MutableLeftOf(parent) = child; - else - MutableRightOf(parent) = child; - } - - MutableParentOf(node) = 0; - MutableLeftOf(node) = 0; - MutableRightOf(node) = 0; - UpdateNode(node); - const uint32_t orphan = node; - - if (root_ == node) root_ = child; - - // Removal is finished. Start the balancing process. - bool needs_rebalancing = true; - node = child; - - while (parent) { - UpdateNode(parent); - - if (needs_rebalancing) { - const int parent_balance = BalanceOf(parent); - - if (parent_balance == 1 || parent_balance == -1) { - // The height of the subtree was not changed. - needs_rebalancing = false; - } else { - if (RightOf(parent) == node) { - // Removed node from the right subtree. - if (parent_balance < -1) { - // Parent is left heavy, rotate right. - const uint32_t sibling = LeftOf(parent); - if (BalanceOf(sibling) > 0) RotateLeft(sibling); - parent = RotateRight(parent); - } - } else { - // Removed node from the left subtree. - if (parent_balance > 1) { - // Parent is right heavy, rotate left. - const uint32_t sibling = RightOf(parent); - if (BalanceOf(sibling) < 0) RotateRight(sibling); - parent = RotateLeft(parent); - } - } - } - } - - assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1)); - - node = parent; - parent = ParentOf(parent); - } - - return orphan; -} - -template -uint32_t MoveToFront::RotateLeft(const uint32_t node) { - const uint32_t pivot = RightOf(node); - assert(pivot); - - // LeftOf(pivot) gets attached to node in place of pivot. - MutableRightOf(node) = LeftOf(pivot); - if (RightOf(node)) MutableParentOf(RightOf(node)) = node; - - // Pivot gets attached to ParentOf(node) in place of node. - MutableParentOf(pivot) = ParentOf(node); - if (!ParentOf(node)) - root_ = pivot; - else if (IsLeftChild(node)) - MutableLeftOf(ParentOf(node)) = pivot; - else - MutableRightOf(ParentOf(node)) = pivot; - - // Node is child of pivot. - MutableLeftOf(pivot) = node; - MutableParentOf(node) = pivot; - - // Update both node and pivot. Pivot is the new parent of node, so node should - // be updated first. - UpdateNode(node); - UpdateNode(pivot); - - return pivot; -} - -template -uint32_t MoveToFront::RotateRight(const uint32_t node) { - const uint32_t pivot = LeftOf(node); - assert(pivot); - - // RightOf(pivot) gets attached to node in place of pivot. - MutableLeftOf(node) = RightOf(pivot); - if (LeftOf(node)) MutableParentOf(LeftOf(node)) = node; - - // Pivot gets attached to ParentOf(node) in place of node. - MutableParentOf(pivot) = ParentOf(node); - if (!ParentOf(node)) - root_ = pivot; - else if (IsLeftChild(node)) - MutableLeftOf(ParentOf(node)) = pivot; - else - MutableRightOf(ParentOf(node)) = pivot; - - // Node is child of pivot. - MutableRightOf(pivot) = node; - MutableParentOf(node) = pivot; - - // Update both node and pivot. Pivot is the new parent of node, so node should - // be updated first. - UpdateNode(node); - UpdateNode(pivot); - - return pivot; -} - -template -void MoveToFront::UpdateNode(uint32_t node) { - MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node)); - MutableHeightOf(node) = - 1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node))); -} - -} // namespace spvutils - -#endif // LIBSPIRV_UTIL_MOVE_TO_FRONT_H_ diff --git a/3rdparty/spirv-tools/source/util/parse_number.cpp b/3rdparty/spirv-tools/source/util/parse_number.cpp index bb87b3dad..c3351c236 100644 --- a/3rdparty/spirv-tools/source/util/parse_number.cpp +++ b/3rdparty/spirv-tools/source/util/parse_number.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "util/parse_number.h" +#include "source/util/parse_number.h" #include #include @@ -21,11 +21,13 @@ #include #include -#include "util/hex_float.h" - -namespace spvutils { +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" +namespace spvtools { +namespace utils { namespace { + // A helper class that temporarily stores error messages and dump the messages // to a string which given as as pointer when it is destructed. If the given // pointer is a nullptr, this class does not store error message. @@ -33,7 +35,7 @@ class ErrorMsgStream { public: explicit ErrorMsgStream(std::string* error_msg_sink) : error_msg_sink_(error_msg_sink) { - if (error_msg_sink_) stream_.reset(new std::ostringstream()); + if (error_msg_sink_) stream_ = MakeUnique(); } ~ErrorMsgStream() { if (error_msg_sink_ && stream_) *error_msg_sink_ = stream_->str(); @@ -145,12 +147,12 @@ EncodeNumberStatus ParseAndEncodeFloatingPointNumber( const auto bit_width = AssumedBitWidth(type); switch (bit_width) { case 16: { - HexFloat> hVal(0); + HexFloat> hVal(0); if (!ParseNumber(text, &hVal)) { ErrorMsgStream(error_msg) << "Invalid 16-bit float literal: " << text; return EncodeNumberStatus::kInvalidText; } - // getAsFloat will return the spvutils::Float16 value, and get_value + // getAsFloat will return the Float16 value, and get_value // will return a uint16_t representing the bits of the float. // The encoding is therefore correct from the perspective of the SPIR-V // spec since the top 16 bits will be 0. @@ -211,4 +213,5 @@ EncodeNumberStatus ParseAndEncodeNumber(const char* text, return ParseAndEncodeIntegerNumber(text, type, emit, error_msg); } -} // namespace spvutils +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/util/parse_number.h b/3rdparty/spirv-tools/source/util/parse_number.h index 2a9bd6d46..729aac54b 100644 --- a/3rdparty/spirv-tools/source/util/parse_number.h +++ b/3rdparty/spirv-tools/source/util/parse_number.h @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_PARSE_NUMBER_H_ -#define LIBSPIRV_UTIL_PARSE_NUMBER_H_ +#ifndef SOURCE_UTIL_PARSE_NUMBER_H_ +#define SOURCE_UTIL_PARSE_NUMBER_H_ #include #include #include +#include "source/util/hex_float.h" #include "spirv-tools/libspirv.h" -#include "util/hex_float.h" -namespace spvutils { +namespace spvtools { +namespace utils { // A struct to hold the expected type information for the number in text to be // parsed. @@ -245,6 +246,7 @@ EncodeNumberStatus ParseAndEncodeNumber(const char* text, std::function emit, std::string* error_msg); -} // namespace spvutils +} // namespace utils +} // namespace spvtools -#endif // LIBSPIRV_UTIL_PARSE_NUMBER_H_ +#endif // SOURCE_UTIL_PARSE_NUMBER_H_ diff --git a/3rdparty/spirv-tools/source/util/small_vector.h b/3rdparty/spirv-tools/source/util/small_vector.h new file mode 100644 index 000000000..f2c1147be --- /dev/null +++ b/3rdparty/spirv-tools/source/util/small_vector.h @@ -0,0 +1,466 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_SMALL_VECTOR_H_ +#define SOURCE_UTIL_SMALL_VECTOR_H_ + +#include +#include +#include +#include +#include + +#include "source/util/make_unique.h" + +namespace spvtools { +namespace utils { + +// The |SmallVector| class is intended to be a drop-in replacement for +// |std::vector|. The difference is in the implementation. A |SmallVector| is +// optimized for when the number of elements in the vector are small. Small is +// defined by the template parameter |small_size|. +// +// Note that |SmallVector| is not always faster than an |std::vector|, so you +// should experiment with different values for |small_size| and compare to +// using and |std::vector|. +// +// TODO: I have implemented the public member functions from |std::vector| that +// I needed. If others are needed they should be implemented. Do not implement +// public member functions that are not defined by std::vector. +template +class SmallVector { + public: + using iterator = T*; + using const_iterator = const T*; + + SmallVector() + : size_(0), + small_data_(reinterpret_cast(buffer)), + large_data_(nullptr) {} + + SmallVector(const SmallVector& that) : SmallVector() { *this = that; } + + SmallVector(SmallVector&& that) : SmallVector() { *this = std::move(that); } + + SmallVector(const std::vector& vec) : SmallVector() { + if (vec.size() > small_size) { + large_data_ = MakeUnique>(vec); + } else { + size_ = vec.size(); + for (uint32_t i = 0; i < size_; i++) { + new (small_data_ + i) T(vec[i]); + } + } + } + + SmallVector(std::vector&& vec) : SmallVector() { + if (vec.size() > small_size) { + large_data_ = MakeUnique>(std::move(vec)); + } else { + size_ = vec.size(); + for (uint32_t i = 0; i < size_; i++) { + new (small_data_ + i) T(std::move(vec[i])); + } + } + vec.clear(); + } + + SmallVector(std::initializer_list init_list) : SmallVector() { + if (init_list.size() < small_size) { + for (auto it = init_list.begin(); it != init_list.end(); ++it) { + new (small_data_ + (size_++)) T(std::move(*it)); + } + } else { + large_data_ = MakeUnique>(std::move(init_list)); + } + } + + SmallVector(size_t s, const T& v) : SmallVector() { resize(s, v); } + + virtual ~SmallVector() { + for (T* p = small_data_; p < small_data_ + size_; ++p) { + p->~T(); + } + } + + SmallVector& operator=(const SmallVector& that) { + assert(small_data_); + if (that.large_data_) { + if (large_data_) { + *large_data_ = *that.large_data_; + } else { + large_data_ = MakeUnique>(*that.large_data_); + } + } else { + large_data_.reset(nullptr); + size_t i = 0; + // Do a copy for any element in |this| that is already constructed. + for (; i < size_ && i < that.size_; ++i) { + small_data_[i] = that.small_data_[i]; + } + + if (i >= that.size_) { + // If the size of |this| becomes smaller after the assignment, then + // destroy any extra elements. + for (; i < size_; ++i) { + small_data_[i].~T(); + } + } else { + // If the size of |this| becomes larger after the assignement, copy + // construct the new elements that are needed. + for (; i < that.size_; ++i) { + new (small_data_ + i) T(that.small_data_[i]); + } + } + size_ = that.size_; + } + return *this; + } + + SmallVector& operator=(SmallVector&& that) { + if (that.large_data_) { + large_data_.reset(that.large_data_.release()); + } else { + large_data_.reset(nullptr); + size_t i = 0; + // Do a move for any element in |this| that is already constructed. + for (; i < size_ && i < that.size_; ++i) { + small_data_[i] = std::move(that.small_data_[i]); + } + + if (i >= that.size_) { + // If the size of |this| becomes smaller after the assignment, then + // destroy any extra elements. + for (; i < size_; ++i) { + small_data_[i].~T(); + } + } else { + // If the size of |this| becomes larger after the assignement, move + // construct the new elements that are needed. + for (; i < that.size_; ++i) { + new (small_data_ + i) T(std::move(that.small_data_[i])); + } + } + size_ = that.size_; + } + + // Reset |that| because all of the data has been moved to |this|. + that.DestructSmallData(); + return *this; + } + + template + friend bool operator==(const SmallVector& lhs, const OtherVector& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + auto rit = rhs.begin(); + for (auto lit = lhs.begin(); lit != lhs.end(); ++lit, ++rit) { + if (*lit != *rit) { + return false; + } + } + return true; + } + + friend bool operator==(const std::vector& lhs, const SmallVector& rhs) { + return rhs == lhs; + } + + friend bool operator!=(const SmallVector& lhs, const std::vector& rhs) { + return !(lhs == rhs); + } + + friend bool operator!=(const std::vector& lhs, const SmallVector& rhs) { + return rhs != lhs; + } + + T& operator[](size_t i) { + if (!large_data_) { + return small_data_[i]; + } else { + return (*large_data_)[i]; + } + } + + const T& operator[](size_t i) const { + if (!large_data_) { + return small_data_[i]; + } else { + return (*large_data_)[i]; + } + } + + size_t size() const { + if (!large_data_) { + return size_; + } else { + return large_data_->size(); + } + } + + iterator begin() { + if (large_data_) { + return large_data_->data(); + } else { + return small_data_; + } + } + + const_iterator begin() const { + if (large_data_) { + return large_data_->data(); + } else { + return small_data_; + } + } + + const_iterator cbegin() const { return begin(); } + + iterator end() { + if (large_data_) { + return large_data_->data() + large_data_->size(); + } else { + return small_data_ + size_; + } + } + + const_iterator end() const { + if (large_data_) { + return large_data_->data() + large_data_->size(); + } else { + return small_data_ + size_; + } + } + + const_iterator cend() const { return end(); } + + T* data() { return begin(); } + + const T* data() const { return cbegin(); } + + T& front() { return (*this)[0]; } + + const T& front() const { return (*this)[0]; } + + iterator erase(const_iterator pos) { return erase(pos, pos + 1); } + + iterator erase(const_iterator first, const_iterator last) { + if (large_data_) { + size_t start_index = first - large_data_->data(); + size_t end_index = last - large_data_->data(); + auto r = large_data_->erase(large_data_->begin() + start_index, + large_data_->begin() + end_index); + return large_data_->data() + (r - large_data_->begin()); + } + + // Since C++11, std::vector has |const_iterator| for the parameters, so I + // follow that. However, I need iterators to modify the current container, + // which is not const. This is why I cast away the const. + iterator f = const_cast(first); + iterator l = const_cast(last); + iterator e = end(); + + size_t num_of_del_elements = last - first; + iterator ret = f; + if (first == last) { + return ret; + } + + // Move |last| and any elements after it their earlier position. + while (l != e) { + *f = std::move(*l); + ++f; + ++l; + } + + // Destroy the elements that were supposed to be deleted. + while (f != l) { + f->~T(); + ++f; + } + + // Update the size. + size_ -= num_of_del_elements; + return ret; + } + + void push_back(const T& value) { + if (!large_data_ && size_ == small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->push_back(value); + return; + } + + new (small_data_ + size_) T(value); + ++size_; + } + + void push_back(T&& value) { + if (!large_data_ && size_ == small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->push_back(std::move(value)); + return; + } + + new (small_data_ + size_) T(std::move(value)); + ++size_; + } + + template + iterator insert(iterator pos, InputIt first, InputIt last) { + size_t element_idx = (pos - begin()); + size_t num_of_new_elements = std::distance(first, last); + size_t new_size = size_ + num_of_new_elements; + if (!large_data_ && new_size > small_size) { + MoveToLargeData(); + } + + if (large_data_) { + typename std::vector::iterator new_pos = + large_data_->begin() + element_idx; + large_data_->insert(new_pos, first, last); + return begin() + element_idx; + } + + // Move |pos| and all of the elements after it over |num_of_new_elements| + // places. We start at the end and work backwards, to make sure we do not + // overwrite data that we have not moved yet. + for (iterator i = begin() + new_size - 1, j = end() - 1; j >= pos; + --i, --j) { + if (i >= begin() + size_) { + new (i) T(std::move(*j)); + } else { + *i = std::move(*j); + } + } + + // Copy the new elements into position. + iterator p = pos; + for (; first != last; ++p, ++first) { + if (p >= small_data_ + size_) { + new (p) T(*first); + } else { + *p = *first; + } + } + + // Upate the size. + size_ += num_of_new_elements; + return pos; + } + + bool empty() const { + if (large_data_) { + return large_data_->empty(); + } + return size_ == 0; + } + + void clear() { + if (large_data_) { + large_data_->clear(); + } else { + DestructSmallData(); + } + } + + template + void emplace_back(Args&&... args) { + if (!large_data_ && size_ == small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->emplace_back(std::forward(args)...); + } else { + new (small_data_ + size_) T(std::forward(args)...); + ++size_; + } + } + + void resize(size_t new_size, const T& v) { + if (!large_data_ && new_size > small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->resize(new_size, v); + return; + } + + // If |new_size| < |size_|, then destroy the extra elements. + for (size_t i = new_size; i < size_; ++i) { + small_data_[i].~T(); + } + + // If |new_size| > |size_|, the copy construct the new elements. + for (size_t i = size_; i < new_size; ++i) { + new (small_data_ + i) T(v); + } + + // Update the size. + size_ = new_size; + } + + private: + // Moves all of the element from |small_data_| into a new std::vector that can + // be access through |large_data|. + void MoveToLargeData() { + assert(!large_data_); + large_data_ = MakeUnique>(); + for (size_t i = 0; i < size_; ++i) { + large_data_->emplace_back(std::move(small_data_[i])); + } + DestructSmallData(); + } + + // Destroys all of the elements in |small_data_| that have been constructed. + void DestructSmallData() { + for (size_t i = 0; i < size_; ++i) { + small_data_[i].~T(); + } + size_ = 0; + } + + // The number of elements in |small_data_| that have been constructed. + size_t size_; + + // The pointed used to access the array of elements when the number of + // elements is small. + T* small_data_; + + // The actual data used to store the array elements. It must never be used + // directly, but must only be accesed through |small_data_|. + typename std::aligned_storage::value>::type + buffer[small_size]; + + // A pointer to a vector that is used to store the elements of the vector when + // this size exceeds |small_size|. If |large_data_| is nullptr, then the data + // is stored in |small_data_|. Otherwise, the data is stored in + // |large_data_|. + std::unique_ptr> large_data_; +}; // namespace utils + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_SMALL_VECTOR_H_ diff --git a/3rdparty/spirv-tools/source/util/string_utils.cpp b/3rdparty/spirv-tools/source/util/string_utils.cpp index 830f1a3b1..29ce2aa4a 100644 --- a/3rdparty/spirv-tools/source/util/string_utils.cpp +++ b/3rdparty/spirv-tools/source/util/string_utils.cpp @@ -16,9 +16,10 @@ #include #include -#include "util/string_utils.h" +#include "source/util/string_utils.h" -namespace spvutils { +namespace spvtools { +namespace utils { std::string CardinalToOrdinal(size_t cardinal) { const size_t mod10 = cardinal % 10; @@ -36,4 +37,5 @@ std::string CardinalToOrdinal(size_t cardinal) { return ToString(cardinal) + suffix; } -} // namespace spvutils +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/util/string_utils.h b/3rdparty/spirv-tools/source/util/string_utils.h index 993b58dfc..322c574fb 100644 --- a/3rdparty/spirv-tools/source/util/string_utils.h +++ b/3rdparty/spirv-tools/source/util/string_utils.h @@ -12,21 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_UTIL_STRING_UTILS_H_ -#define LIBSPIRV_UTIL_STRING_UTILS_H_ +#ifndef SOURCE_UTIL_STRING_UTILS_H_ +#define SOURCE_UTIL_STRING_UTILS_H_ #include #include -#include "util/string_utils.h" +#include "source/util/string_utils.h" -namespace spvutils { +namespace spvtools { +namespace utils { // Converts arithmetic value |val| to its default string representation. template std::string ToString(T val) { - static_assert(std::is_arithmetic::value, - "spvutils::ToString is restricted to only arithmetic values"); + static_assert( + std::is_arithmetic::value, + "spvtools::utils::ToString is restricted to only arithmetic values"); std::stringstream os; os << val; return os.str(); @@ -35,6 +37,7 @@ std::string ToString(T val) { // Converts cardinal number to ordinal number string. std::string CardinalToOrdinal(size_t cardinal); -} // namespace spvutils +} // namespace utils +} // namespace spvtools -#endif // LIBSPIRV_UTIL_STRING_UTILS_H_ +#endif // SOURCE_UTIL_STRING_UTILS_H_ diff --git a/3rdparty/spirv-tools/source/util/timer.cpp b/3rdparty/spirv-tools/source/util/timer.cpp index 722bc6c41..c8b8d5b61 100644 --- a/3rdparty/spirv-tools/source/util/timer.cpp +++ b/3rdparty/spirv-tools/source/util/timer.cpp @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "util/timer.h" +#if defined(SPIRV_TIMER_ENABLED) + +#include "source/util/timer.h" #include #include @@ -20,9 +22,8 @@ #include #include -namespace spvutils { - -#if defined(SPIRV_TIMER_ENABLED) +namespace spvtools { +namespace utils { void PrintTimerDescription(std::ostream* out, bool measure_mem_usage) { if (out) { @@ -95,6 +96,7 @@ void Timer::Report(const char* tag) { *report_stream_ << std::endl; } -#endif // defined(SPIRV_TIMER_ENABLED) +} // namespace utils +} // namespace spvtools -} // namespace spvutils +#endif // defined(SPIRV_TIMER_ENABLED) diff --git a/3rdparty/spirv-tools/source/util/timer.h b/3rdparty/spirv-tools/source/util/timer.h index c6af24e7a..fc4b747b9 100644 --- a/3rdparty/spirv-tools/source/util/timer.h +++ b/3rdparty/spirv-tools/source/util/timer.h @@ -14,8 +14,8 @@ // Contains utils for getting resource utilization -#ifndef LIBSPIRV_UTIL_TIMER_H_ -#define LIBSPIRV_UTIL_TIMER_H_ +#ifndef SOURCE_UTIL_TIMER_H_ +#define SOURCE_UTIL_TIMER_H_ #if defined(SPIRV_TIMER_ENABLED) @@ -23,16 +23,17 @@ #include #include -// A macro to call spvutils::PrintTimerDescription(std::ostream*, bool). The -// first argument must be given as std::ostream*. If it is NULL, the function -// does nothing. Otherwise, it prints resource types measured by Timer class. -// The second is optional and if it is true, the function also prints resource -// type fields related to memory. Otherwise, it does not print memory related -// fields. Its default is false. In usual, this must be placed before calling -// Timer::Report() to inform what those fields printed by Timer::Report() -// indicate (or spvutils::PrintTimerDescription() must be used instead). +// A macro to call spvtools::utils::PrintTimerDescription(std::ostream*, bool). +// The first argument must be given as std::ostream*. If it is NULL, the +// function does nothing. Otherwise, it prints resource types measured by Timer +// class. The second is optional and if it is true, the function also prints +// resource type fields related to memory. Otherwise, it does not print memory +// related fields. Its default is false. In usual, this must be placed before +// calling Timer::Report() to inform what those fields printed by +// Timer::Report() indicate (or spvtools::utils::PrintTimerDescription() must be +// used instead). #define SPIRV_TIMER_DESCRIPTION(...) \ - spvutils::PrintTimerDescription(__VA_ARGS__) + spvtools::utils::PrintTimerDescription(__VA_ARGS__) // Creates an object of ScopedTimer to measure the resource utilization for the // scope surrounding it as the following example: @@ -47,10 +48,12 @@ // // } // <-- end of this scope. The destructor of ScopedTimer prints tag and // the resource utilization to std::cout. -#define SPIRV_TIMER_SCOPED(...) \ - spvutils::ScopedTimer timer##__LINE__(__VA_ARGS__) +#define SPIRV_TIMER_SCOPED(...) \ + spvtools::utils::ScopedTimer timer##__LINE__( \ + __VA_ARGS__) -namespace spvutils { +namespace spvtools { +namespace utils { // Prints the description of resource types measured by Timer class. If |out| is // NULL, it does nothing. Otherwise, it prints resource types. The second is @@ -78,7 +81,7 @@ enum UsageStatus { // only when |measure_mem_usage| given to the constructor is true. This class // should be used as the following example: // -// spvutils::Timer timer(std::cout); +// spvtools::utils::Timer timer(std::cout); // timer.Start(); // <-- set |usage_before_|, |wall_before_|, // and |cpu_before_| // @@ -232,7 +235,8 @@ class Timer { // // /* ... code out of interest ... */ // -// spvutils::ScopedTimer scopedtimer(std::cout, tag); +// spvtools::utils::ScopedTimer +// scopedtimer(std::cout, tag); // // /* ... lines of code that we want to know its resource usage ... */ // @@ -375,7 +379,8 @@ class CumulativeTimer : public Timer { long pgfaults_; }; -} // namespace spvutils +} // namespace utils +} // namespace spvtools #else // defined(SPIRV_TIMER_ENABLED) @@ -384,4 +389,4 @@ class CumulativeTimer : public Timer { #endif // defined(SPIRV_TIMER_ENABLED) -#endif // LIBSPIRV_UTIL_TIMER_H_ +#endif // SOURCE_UTIL_TIMER_H_ diff --git a/3rdparty/spirv-tools/source/val/basic_block.cpp b/3rdparty/spirv-tools/source/val/basic_block.cpp index a0b10fa71..a53103c8a 100644 --- a/3rdparty/spirv-tools/source/val/basic_block.cpp +++ b/3rdparty/spirv-tools/source/val/basic_block.cpp @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "val/basic_block.h" +#include "source/val/basic_block.h" #include #include #include -using std::vector; - -namespace libspirv { +namespace spvtools { +namespace val { BasicBlock::BasicBlock(uint32_t label_id) : id_(label_id), @@ -29,7 +28,9 @@ BasicBlock::BasicBlock(uint32_t label_id) predecessors_(), successors_(), type_(0), - reachable_(false) {} + reachable_(false), + label_(nullptr), + terminator_(nullptr) {} void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) { immediate_dominator_ = dom_block; @@ -52,7 +53,8 @@ BasicBlock* BasicBlock::immediate_post_dominator() { return immediate_post_dominator_; } -void BasicBlock::RegisterSuccessors(const vector& next_blocks) { +void BasicBlock::RegisterSuccessors( + const std::vector& next_blocks) { for (auto& block : next_blocks) { block->predecessors_.push_back(this); successors_.push_back(block); @@ -142,4 +144,6 @@ bool operator!=(const BasicBlock::DominatorIterator& lhs, const BasicBlock*& BasicBlock::DominatorIterator::operator*() { return current_; } -} // namespace libspirv + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/basic_block.h b/3rdparty/spirv-tools/source/val/basic_block.h index c2a5bb8fd..efbd243b6 100644 --- a/3rdparty/spirv-tools/source/val/basic_block.h +++ b/3rdparty/spirv-tools/source/val/basic_block.h @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VAL_BASICBLOCK_H_ -#define LIBSPIRV_VAL_BASICBLOCK_H_ - -#include "latest_version_spirv_header.h" +#ifndef SOURCE_VAL_BASIC_BLOCK_H_ +#define SOURCE_VAL_BASIC_BLOCK_H_ #include - #include #include #include #include -namespace libspirv { +#include "source/latest_version_spirv_header.h" + +namespace spvtools { +namespace val { enum BlockType : uint32_t { kBlockTypeUndefined, @@ -37,6 +37,8 @@ enum BlockType : uint32_t { kBlockTypeCOUNT ///< Total number of block types. (must be the last element) }; +class Instruction; + // This class represents a basic block in a SPIR-V module class BasicBlock { public: @@ -107,6 +109,18 @@ class BasicBlock { /// Ends the block without a successor void RegisterBranchInstruction(SpvOp branch_instruction); + /// Returns the label instruction for the block, or nullptr if not set. + const Instruction* label() const { return label_; } + + //// Registers the label instruction for the block. + void set_label(const Instruction* t) { label_ = t; } + + /// Registers the terminator instruction for the block. + void set_terminator(const Instruction* t) { terminator_ = t; } + + /// Returns the terminator instruction for the block. + const Instruction* terminator() const { return terminator_; } + /// Adds @p next BasicBlocks as successors of this BasicBlock void RegisterSuccessors( const std::vector& next = std::vector()); @@ -205,10 +219,16 @@ class BasicBlock { std::vector successors_; /// The type of the block - std::bitset type_; + std::bitset type_; /// True if the block is reachable in the CFG bool reachable_; + + /// label of this block, if any. + const Instruction* label_; + + /// Terminator of this block. + const Instruction* terminator_; }; /// @brief Returns true if the iterators point to the same element or if both @@ -221,6 +241,7 @@ bool operator==(const BasicBlock::DominatorIterator& lhs, bool operator!=(const BasicBlock::DominatorIterator& lhs, const BasicBlock::DominatorIterator& rhs); -} // namespace libspirv +} // namespace val +} // namespace spvtools -#endif /// LIBSPIRV_VAL_BASICBLOCK_H_ +#endif // SOURCE_VAL_BASIC_BLOCK_H_ diff --git a/3rdparty/spirv-tools/source/val/construct.cpp b/3rdparty/spirv-tools/source/val/construct.cpp index c5f01dfb6..c11a065b7 100644 --- a/3rdparty/spirv-tools/source/val/construct.cpp +++ b/3rdparty/spirv-tools/source/val/construct.cpp @@ -12,12 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "val/construct.h" +#include "source/val/construct.h" #include #include +#include -namespace libspirv { +#include "source/val/function.h" + +namespace spvtools { +namespace val { Construct::Construct(ConstructType construct_type, BasicBlock* entry, BasicBlock* exit, std::vector constructs) @@ -64,4 +68,61 @@ const BasicBlock* Construct::exit_block() const { return exit_block_; } BasicBlock* Construct::exit_block() { return exit_block_; } void Construct::set_exit(BasicBlock* block) { exit_block_ = block; } -} // namespace libspirv + +Construct::ConstructBlockSet Construct::blocks(Function* function) const { + auto header = entry_block(); + auto merge = exit_block(); + assert(header); + assert(merge); + int header_depth = function->GetBlockDepth(const_cast(header)); + ConstructBlockSet construct_blocks; + std::unordered_set corresponding_headers; + for (auto& other : corresponding_constructs()) { + corresponding_headers.insert(other->entry_block()); + } + std::vector stack; + stack.push_back(const_cast(header)); + while (!stack.empty()) { + BasicBlock* block = stack.back(); + stack.pop_back(); + + if (merge == block && ExitBlockIsMergeBlock()) { + // Merge block is not part of the construct. + continue; + } + + if (corresponding_headers.count(block)) { + // Entered a corresponding construct. + continue; + } + + int block_depth = function->GetBlockDepth(block); + if (block_depth < header_depth) { + // Broke to outer construct. + continue; + } + + // In a loop, the continue target is at a depth of the loop construct + 1. + // A selection construct nested directly within the loop construct is also + // at the same depth. It is valid, however, to branch directly to the + // continue target from within the selection construct. + if (block_depth == header_depth && type() == ConstructType::kSelection && + block->is_type(kBlockTypeContinue)) { + // Continued to outer construct. + continue; + } + + if (!construct_blocks.insert(block).second) continue; + + if (merge != block) { + for (auto succ : *block->successors()) { + stack.push_back(succ); + } + } + } + + return construct_blocks; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/construct.h b/3rdparty/spirv-tools/source/val/construct.h index 594d8d14b..c7e7a780d 100644 --- a/3rdparty/spirv-tools/source/val/construct.h +++ b/3rdparty/spirv-tools/source/val/construct.h @@ -12,13 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VAL_CONSTRUCT_H_ -#define LIBSPIRV_VAL_CONSTRUCT_H_ +#ifndef SOURCE_VAL_CONSTRUCT_H_ +#define SOURCE_VAL_CONSTRUCT_H_ #include +#include #include -namespace libspirv { +#include "source/val/basic_block.h" + +namespace spvtools { +namespace val { + +/// Functor for ordering BasicBlocks. BasicBlock pointers must not be null. +struct less_than_id { + bool operator()(const BasicBlock* lhs, const BasicBlock* rhs) const { + return lhs->id() < rhs->id(); + } +}; enum class ConstructType : int { kNone = 0, @@ -39,7 +50,7 @@ enum class ConstructType : int { kCase }; -class BasicBlock; +class Function; /// @brief This class tracks the CFG constructs as defined in the SPIR-V spec class Construct { @@ -91,6 +102,13 @@ class Construct { return type_ == ConstructType::kLoop || type_ == ConstructType::kSelection; } + using ConstructBlockSet = std::set; + + // Returns the basic blocks in this construct. This function should not + // be called before the exit block is set and dominators have been + // calculated. + ConstructBlockSet blocks(Function* function) const; + private: /// The type of the construct ConstructType type_; @@ -127,6 +145,7 @@ class Construct { BasicBlock* exit_block_; }; -} // namespace libspirv +} // namespace val +} // namespace spvtools -#endif /// LIBSPIRV_VAL_CONSTRUCT_H_ +#endif // SOURCE_VAL_CONSTRUCT_H_ diff --git a/3rdparty/spirv-tools/source/val/decoration.h b/3rdparty/spirv-tools/source/val/decoration.h index 8d2899157..ed3320f87 100644 --- a/3rdparty/spirv-tools/source/val/decoration.h +++ b/3rdparty/spirv-tools/source/val/decoration.h @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VAL_DECORATION_H_ -#define LIBSPIRV_VAL_DECORATION_H_ +#ifndef SOURCE_VAL_DECORATION_H_ +#define SOURCE_VAL_DECORATION_H_ #include #include #include -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" -namespace libspirv { +namespace spvtools { +namespace val { // An object of this class represents a specific decoration including its // parameters (if any). Decorations are used by OpDecorate and OpMemberDecorate, @@ -82,6 +83,7 @@ class Decoration { int struct_member_index_; }; -} // namespace libspirv +} // namespace val +} // namespace spvtools -#endif /// LIBSPIRV_VAL_DECORATION_H_ +#endif // SOURCE_VAL_DECORATION_H_ diff --git a/3rdparty/spirv-tools/source/val/function.cpp b/3rdparty/spirv-tools/source/val/function.cpp index 7f49ad0c2..f638fb5b4 100644 --- a/3rdparty/spirv-tools/source/val/function.cpp +++ b/3rdparty/spirv-tools/source/val/function.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "val/function.h" +#include "source/val/function.h" #include @@ -22,19 +22,13 @@ #include #include -#include "cfa.h" -#include "val/basic_block.h" -#include "val/construct.h" -#include "validate.h" +#include "source/cfa.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "source/val/validate.h" -using std::ignore; -using std::list; -using std::make_pair; -using std::pair; -using std::tie; -using std::vector; - -namespace libspirv { +namespace spvtools { +namespace val { // Universal Limit of ResultID + 1 static const uint32_t kInvalidId = 0x400000; @@ -139,12 +133,12 @@ spv_result_t Function::RegisterBlock(uint32_t block_id, bool is_definition) { return SPV_SUCCESS; } -void Function::RegisterBlockEnd(vector next_list, +void Function::RegisterBlockEnd(std::vector next_list, SpvOp branch_instruction) { assert( current_block_ && "RegisterBlockEnd can only be called when parsing a binary in a block"); - vector next_blocks; + std::vector next_blocks; next_blocks.reserve(next_list.size()); std::unordered_map::iterator inserted_block; @@ -195,16 +189,18 @@ size_t Function::undefined_block_count() const { return undefined_blocks_.size(); } -const vector& Function::ordered_blocks() const { +const std::vector& Function::ordered_blocks() const { return ordered_blocks_; } -vector& Function::ordered_blocks() { return ordered_blocks_; } +std::vector& Function::ordered_blocks() { return ordered_blocks_; } const BasicBlock* Function::current_block() const { return current_block_; } BasicBlock* Function::current_block() { return current_block_; } -const list& Function::constructs() const { return cfg_constructs_; } -list& Function::constructs() { return cfg_constructs_; } +const std::list& Function::constructs() const { + return cfg_constructs_; +} +std::list& Function::constructs() { return cfg_constructs_; } const BasicBlock* Function::first_block() const { if (ordered_blocks_.empty()) return nullptr; @@ -218,30 +214,31 @@ BasicBlock* Function::first_block() { bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const { bool ret = false; const BasicBlock* block; - tie(block, ignore) = GetBlock(merge_block_id); + std::tie(block, std::ignore) = GetBlock(merge_block_id); if (block) { ret = block->is_type(type); } return ret; } -pair Function::GetBlock(uint32_t block_id) const { +std::pair Function::GetBlock(uint32_t block_id) const { const auto b = blocks_.find(block_id); if (b != end(blocks_)) { const BasicBlock* block = &(b->second); bool defined = - undefined_blocks_.find(block->id()) == end(undefined_blocks_); - return make_pair(block, defined); + undefined_blocks_.find(block->id()) == std::end(undefined_blocks_); + return std::make_pair(block, defined); } else { - return make_pair(nullptr, false); + return std::make_pair(nullptr, false); } } -pair Function::GetBlock(uint32_t block_id) { +std::pair Function::GetBlock(uint32_t block_id) { const BasicBlock* out; bool defined; - tie(out, defined) = const_cast(this)->GetBlock(block_id); - return make_pair(const_cast(out), defined); + std::tie(out, defined) = + const_cast(this)->GetBlock(block_id); + return std::make_pair(const_cast(out), defined); } Function::GetBlocksFunction Function::AugmentedCFGSuccessorsFunction() const { @@ -275,7 +272,7 @@ void Function::ComputeAugmentedCFG() { // the predecessors of the pseudo exit block. auto succ_func = [](const BasicBlock* b) { return b->successors(); }; auto pred_func = [](const BasicBlock* b) { return b->predecessors(); }; - spvtools::CFA::ComputeAugmentedCFG( + CFA::ComputeAugmentedCFG( ordered_blocks_, &pseudo_entry_block_, &pseudo_exit_block_, &augmented_successors_map_, &augmented_predecessors_map_, succ_func, pred_func); @@ -386,4 +383,5 @@ bool Function::IsCompatibleWithExecutionModel(SpvExecutionModel model, return return_value; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/function.h b/3rdparty/spirv-tools/source/val/function.h index 1984654f0..a052bbda0 100644 --- a/3rdparty/spirv-tools/source/val/function.h +++ b/3rdparty/spirv-tools/source/val/function.h @@ -12,23 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VAL_FUNCTION_H_ -#define LIBSPIRV_VAL_FUNCTION_H_ +#ifndef SOURCE_VAL_FUNCTION_H_ +#define SOURCE_VAL_FUNCTION_H_ #include #include #include #include +#include #include #include +#include #include -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" #include "spirv-tools/libspirv.h" -#include "val/basic_block.h" -#include "val/construct.h" -namespace libspirv { +namespace spvtools { +namespace val { struct bb_constr_type_pair_hash { std::size_t operator()( @@ -331,7 +334,7 @@ class Function { /// constructs, the type of the construct should also be specified in order to /// get the unique construct. std::unordered_map, Construct*, - libspirv::bb_constr_type_pair_hash> + bb_constr_type_pair_hash> entry_block_to_construct_; /// This map provides the header block for a given merge block. @@ -351,6 +354,7 @@ class Function { std::set function_call_targets_; }; -} // namespace libspirv +} // namespace val +} // namespace spvtools -#endif /// LIBSPIRV_VAL_FUNCTION_H_ +#endif // SOURCE_VAL_FUNCTION_H_ diff --git a/3rdparty/spirv-tools/source/val/instruction.cpp b/3rdparty/spirv-tools/source/val/instruction.cpp index 56bd37f5e..b9155898a 100644 --- a/3rdparty/spirv-tools/source/val/instruction.cpp +++ b/3rdparty/spirv-tools/source/val/instruction.cpp @@ -12,38 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "val/instruction.h" +#include "source/val/instruction.h" #include -using std::make_pair; +namespace spvtools { +namespace val { -namespace libspirv { -#define OPERATOR(OP) \ - bool operator OP(const Instruction& lhs, const Instruction& rhs) { \ - return lhs.id() OP rhs.id(); \ - } \ - bool operator OP(const Instruction& lhs, uint32_t rhs) { \ - return lhs.id() OP rhs; \ - } - -OPERATOR(<) -OPERATOR(==) -#undef OPERATOR - -Instruction::Instruction(const spv_parsed_instruction_t* inst, - Function* defining_function, - BasicBlock* defining_block) +Instruction::Instruction(const spv_parsed_instruction_t* inst) : words_(inst->words, inst->words + inst->num_words), operands_(inst->operands, inst->operands + inst->num_operands), inst_({words_.data(), inst->num_words, inst->opcode, inst->ext_inst_type, inst->type_id, inst->result_id, operands_.data(), - inst->num_operands}), - function_(defining_function), - block_(defining_block), - uses_() {} + inst->num_operands}) {} void Instruction::RegisterUse(const Instruction* inst, uint32_t index) { - uses_.push_back(make_pair(inst, index)); + uses_.push_back(std::make_pair(inst, index)); } -} // namespace libspirv + +bool operator<(const Instruction& lhs, const Instruction& rhs) { + return lhs.id() < rhs.id(); +} +bool operator<(const Instruction& lhs, uint32_t rhs) { return lhs.id() < rhs; } +bool operator==(const Instruction& lhs, const Instruction& rhs) { + return lhs.id() == rhs.id(); +} +bool operator==(const Instruction& lhs, uint32_t rhs) { + return lhs.id() == rhs; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/instruction.h b/3rdparty/spirv-tools/source/val/instruction.h index 96136320e..1fa855fca 100644 --- a/3rdparty/spirv-tools/source/val/instruction.h +++ b/3rdparty/spirv-tools/source/val/instruction.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VAL_INSTRUCTION_H_ -#define LIBSPIRV_VAL_INSTRUCTION_H_ +#ifndef SOURCE_VAL_INSTRUCTION_H_ +#define SOURCE_VAL_INSTRUCTION_H_ #include #include @@ -21,10 +21,11 @@ #include #include +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" -namespace libspirv { +namespace spvtools { +namespace val { class BasicBlock; class Function; @@ -33,9 +34,7 @@ class Function; /// instruction's result id class Instruction { public: - explicit Instruction(const spv_parsed_instruction_t* inst, - Function* defining_function = nullptr, - BasicBlock* defining_block = nullptr); + explicit Instruction(const spv_parsed_instruction_t* inst); /// Registers the use of the Instruction in instruction \p inst at \p index void RegisterUse(const Instruction* inst, uint32_t index); @@ -47,10 +46,12 @@ class Instruction { /// Returns the Function where the instruction was defined. nullptr if it was /// defined outside of a Function const Function* function() const { return function_; } + void set_function(Function* func) { function_ = func; } /// Returns the BasicBlock where the instruction was defined. nullptr if it /// was defined outside of a BasicBlock const BasicBlock* block() const { return block_; } + void set_block(BasicBlock* b) { block_ = b; } /// Returns a vector of pairs of all references to this instruction's result /// id. The first element is the instruction in which this result id was @@ -66,6 +67,11 @@ class Instruction { /// The words used to define the Instruction const std::vector& words() const { return words_; } + /// Returns the operand at |idx|. + const spv_parsed_operand_t& operand(size_t idx) const { + return operands_[idx]; + } + /// The operands of the Instruction const std::vector& operands() const { return operands_; @@ -74,25 +80,34 @@ class Instruction { /// Provides direct access to the stored C instruction object. const spv_parsed_instruction_t& c_inst() const { return inst_; } + /// Provides direct access to instructions spv_ext_inst_type_t object. + const spv_ext_inst_type_t& ext_inst_type() const { + return inst_.ext_inst_type; + } + // Casts the words belonging to the operand under |index| to |T| and returns. template T GetOperandAs(size_t index) const { - const spv_parsed_operand_t& operand = operands_.at(index); - assert(operand.num_words * 4 >= sizeof(T)); - assert(operand.offset + operand.num_words <= inst_.num_words); - return *reinterpret_cast(&words_[operand.offset]); + const spv_parsed_operand_t& o = operands_.at(index); + assert(o.num_words * 4 >= sizeof(T)); + assert(o.offset + o.num_words <= inst_.num_words); + return *reinterpret_cast(&words_[o.offset]); } + size_t LineNum() const { return line_num_; } + void SetLineNum(size_t pos) { line_num_ = pos; } + private: const std::vector words_; const std::vector operands_; spv_parsed_instruction_t inst_; + size_t line_num_ = 0; /// The function in which this instruction was declared - Function* function_; + Function* function_ = nullptr; /// The basic block in which this instruction was declared - BasicBlock* block_; + BasicBlock* block_ = nullptr; /// This is a vector of pairs of all references to this instruction's result /// id. The first element is the instruction in which this result id was @@ -101,26 +116,25 @@ class Instruction { std::vector> uses_; }; -#define OPERATOR(OP) \ - bool operator OP(const Instruction& lhs, const Instruction& rhs); \ - bool operator OP(const Instruction& lhs, uint32_t rhs) +bool operator<(const Instruction& lhs, const Instruction& rhs); +bool operator<(const Instruction& lhs, uint32_t rhs); +bool operator==(const Instruction& lhs, const Instruction& rhs); +bool operator==(const Instruction& lhs, uint32_t rhs); -OPERATOR(<); -OPERATOR(==); -#undef OPERATOR - -} // namespace libspirv +} // namespace val +} // namespace spvtools // custom specialization of std::hash for Instruction namespace std { template <> -struct hash { - typedef libspirv::Instruction argument_type; +struct hash { + typedef spvtools::val::Instruction argument_type; typedef std::size_t result_type; result_type operator()(const argument_type& inst) const { return hash()(inst.id()); } }; + } // namespace std -#endif // LIBSPIRV_VAL_INSTRUCTION_H_ +#endif // SOURCE_VAL_INSTRUCTION_H_ diff --git a/3rdparty/spirv-tools/source/val/validate.cpp b/3rdparty/spirv-tools/source/val/validate.cpp new file mode 100644 index 000000000..84ec193a1 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate.cpp @@ -0,0 +1,429 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" +#include "source/spirv_target_env.h" +#include "source/spirv_validator_options.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { +namespace { + +// TODO(umar): Validate header +// TODO(umar): The binary parser validates the magic word, and the length of the +// header, but nothing else. +spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t, + uint32_t version, uint32_t generator, uint32_t id_bound, + uint32_t) { + // Record the ID bound so that the validator can ensure no ID is out of bound. + ValidationState_t& _ = *(reinterpret_cast(user_data)); + _.setIdBound(id_bound); + _.setGenerator(generator); + _.setVersion(version); + + return SPV_SUCCESS; +} + +// Parses OpExtension instruction and registers extension. +void RegisterExtension(ValidationState_t& _, + const spv_parsed_instruction_t* inst) { + const std::string extension_str = spvtools::GetExtensionString(inst); + Extension extension; + if (!GetExtensionFromString(extension_str.c_str(), &extension)) { + // The error will be logged in the ProcessInstruction pass. + return; + } + + _.RegisterExtension(extension); +} + +// Parses the beginning of the module searching for OpExtension instructions. +// Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION +// once an instruction which is not SpvOpCapability and SpvOpExtension is +// encountered. According to the SPIR-V spec extensions are declared after +// capabilities and before everything else. +spv_result_t ProcessExtensions(void* user_data, + const spv_parsed_instruction_t* inst) { + const SpvOp opcode = static_cast(inst->opcode); + if (opcode == SpvOpCapability) return SPV_SUCCESS; + + if (opcode == SpvOpExtension) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + RegisterExtension(_, inst); + return SPV_SUCCESS; + } + + // OpExtension block is finished, requesting termination. + return SPV_REQUESTED_TERMINATION; +} + +spv_result_t ProcessInstruction(void* user_data, + const spv_parsed_instruction_t* inst) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + + auto* instruction = _.AddOrderedInstruction(inst); + _.RegisterDebugInstruction(instruction); + + return SPV_SUCCESS; +} + +void printDot(const ValidationState_t& _, const BasicBlock& other) { + std::string block_string; + if (other.successors()->empty()) { + block_string += "end "; + } else { + for (auto block : *other.successors()) { + block_string += _.getIdOrName(block->id()) + " "; + } + } + printf("%10s -> {%s\b}\n", _.getIdOrName(other.id()).c_str(), + block_string.c_str()); +} + +void PrintBlocks(ValidationState_t& _, Function func) { + assert(func.first_block()); + + printf("%10s -> %s\n", _.getIdOrName(func.id()).c_str(), + _.getIdOrName(func.first_block()->id()).c_str()); + for (const auto& block : func.ordered_blocks()) { + printDot(_, *block); + } +} + +#ifdef __clang__ +#define UNUSED(func) [[gnu::unused]] func +#elif defined(__GNUC__) +#define UNUSED(func) \ + func __attribute__((unused)); \ + func +#elif defined(_MSC_VER) +#define UNUSED(func) func +#endif + +UNUSED(void PrintDotGraph(ValidationState_t& _, Function func)) { + if (func.first_block()) { + std::string func_name(_.getIdOrName(func.id())); + printf("digraph %s {\n", func_name.c_str()); + PrintBlocks(_, func); + printf("}\n"); + } +} + +spv_result_t ValidateForwardDecls(ValidationState_t& _) { + if (_.unresolved_forward_id_count() == 0) return SPV_SUCCESS; + + std::stringstream ss; + std::vector ids = _.UnresolvedForwardIds(); + + std::transform( + std::begin(ids), std::end(ids), + std::ostream_iterator(ss, " "), + bind(&ValidationState_t::getIdName, std::ref(_), std::placeholders::_1)); + + auto id_str = ss.str(); + return _.diag(SPV_ERROR_INVALID_ID, nullptr) + << "The following forward referenced IDs have not been defined:\n" + << id_str.substr(0, id_str.size() - 1); +} + +// Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the +// SPIRV spec: +// * There is at least one OpEntryPoint instruction, unless the Linkage +// capability is being used. +// * No function can be targeted by both an OpEntryPoint instruction and an +// OpFunctionCall instruction. +spv_result_t ValidateEntryPoints(ValidationState_t& _) { + _.ComputeFunctionToEntryPointMapping(); + + if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) { + return _.diag(SPV_ERROR_INVALID_BINARY, nullptr) + << "No OpEntryPoint instruction was found. This is only allowed if " + "the Linkage capability is being used."; + } + for (const auto& entry_point : _.entry_points()) { + if (_.IsFunctionCallTarget(entry_point)) { + return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point)) + << "A function (" << entry_point + << ") may not be targeted by both an OpEntryPoint instruction and " + "an OpFunctionCall instruction."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateBinaryUsingContextAndValidationState( + const spv_context_t& context, const uint32_t* words, const size_t num_words, + spv_diagnostic* pDiagnostic, ValidationState_t* vstate) { + auto binary = std::unique_ptr( + new spv_const_binary_t{words, num_words}); + + spv_endianness_t endian; + spv_position_t position = {}; + if (spvBinaryEndianness(binary.get(), &endian)) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Invalid SPIR-V magic number."; + } + + spv_header_t header; + if (spvBinaryHeaderGet(binary.get(), endian, &header)) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Invalid SPIR-V header."; + } + + if (header.version > spvVersionForTargetEnv(context.target_env)) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_WRONG_VERSION) + << "Invalid SPIR-V binary version " + << SPV_SPIRV_VERSION_MAJOR_PART(header.version) << "." + << SPV_SPIRV_VERSION_MINOR_PART(header.version) + << " for target environment " + << spvTargetEnvDescription(context.target_env) << "."; + } + + // Look for OpExtension instructions and register extensions. + spvBinaryParse(&context, vstate, words, num_words, + /* parsed_header = */ nullptr, ProcessExtensions, + /* diagnostic = */ nullptr); + + // Parse the module and perform inline validation checks. These checks do + // not require the the knowledge of the whole module. + if (auto error = spvBinaryParse(&context, vstate, words, num_words, setHeader, + ProcessInstruction, pDiagnostic)) { + return error; + } + + for (auto& instruction : vstate->ordered_instructions()) { + { + // In order to do this work outside of Process Instruction we need to be + // able to, briefly, de-const the instruction. + Instruction* inst = const_cast(&instruction); + + if (inst->opcode() == SpvOpEntryPoint) { + const auto entry_point = inst->GetOperandAs(1); + const auto execution_model = inst->GetOperandAs(0); + const char* str = reinterpret_cast( + inst->words().data() + inst->operand(2).offset); + + ValidationState_t::EntryPointDescription desc; + desc.name = str; + + std::vector interfaces; + for (size_t j = 3; j < inst->operands().size(); ++j) + desc.interfaces.push_back(inst->word(inst->operand(j).offset)); + + vstate->RegisterEntryPoint(entry_point, execution_model, + std::move(desc)); + } + if (inst->opcode() == SpvOpFunctionCall) { + if (!vstate->in_function_body()) { + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction) + << "A FunctionCall must happen within a function body."; + } + + vstate->AddFunctionCallTarget(inst->GetOperandAs(2)); + } + + if (vstate->in_function_body()) { + inst->set_function(&(vstate->current_function())); + inst->set_block(vstate->current_function().current_block()); + + if (vstate->in_block() && spvOpcodeIsBlockTerminator(inst->opcode())) { + vstate->current_function().current_block()->set_terminator(inst); + } + } + + if (auto error = IdPass(*vstate, inst)) return error; + } + + if (auto error = CapabilityPass(*vstate, &instruction)) return error; + if (auto error = DataRulesPass(*vstate, &instruction)) return error; + if (auto error = ModuleLayoutPass(*vstate, &instruction)) return error; + if (auto error = CfgPass(*vstate, &instruction)) return error; + if (auto error = InstructionPass(*vstate, &instruction)) return error; + + // Now that all of the checks are done, update the state. + { + Instruction* inst = const_cast(&instruction); + vstate->RegisterInstruction(inst); + } + if (auto error = UpdateIdUse(*vstate, &instruction)) return error; + } + + if (!vstate->has_memory_model_specified()) + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr) + << "Missing required OpMemoryModel instruction."; + + if (vstate->in_function_body()) + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr) + << "Missing OpFunctionEnd at end of module."; + + // Catch undefined forward references before performing further checks. + if (auto error = ValidateForwardDecls(*vstate)) return error; + + // Validate individual opcodes. + for (size_t i = 0; i < vstate->ordered_instructions().size(); ++i) { + auto& instruction = vstate->ordered_instructions()[i]; + + // Keep these passes in the order they appear in the SPIR-V specification + // sections to maintain test consistency. + // Miscellaneous + if (auto error = DebugPass(*vstate, &instruction)) return error; + if (auto error = AnnotationPass(*vstate, &instruction)) return error; + if (auto error = ExtInstPass(*vstate, &instruction)) return error; + if (auto error = ModeSettingPass(*vstate, &instruction)) return error; + if (auto error = TypePass(*vstate, &instruction)) return error; + if (auto error = ConstantPass(*vstate, &instruction)) return error; + if (auto error = ValidateMemoryInstructions(*vstate, &instruction)) + return error; + if (auto error = FunctionPass(*vstate, &instruction)) return error; + if (auto error = ImagePass(*vstate, &instruction)) return error; + if (auto error = ConversionPass(*vstate, &instruction)) return error; + if (auto error = CompositesPass(*vstate, &instruction)) return error; + if (auto error = ArithmeticsPass(*vstate, &instruction)) return error; + if (auto error = BitwisePass(*vstate, &instruction)) return error; + if (auto error = LogicalsPass(*vstate, &instruction)) return error; + if (auto error = ControlFlowPass(*vstate, &instruction)) return error; + if (auto error = DerivativesPass(*vstate, &instruction)) return error; + if (auto error = AtomicsPass(*vstate, &instruction)) return error; + if (auto error = PrimitivesPass(*vstate, &instruction)) return error; + if (auto error = BarriersPass(*vstate, &instruction)) return error; + // Group + // Device-Side Enqueue + // Pipe + if (auto error = NonUniformPass(*vstate, &instruction)) return error; + + if (auto error = LiteralsPass(*vstate, &instruction)) return error; + // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi + // must only be preceeded by SpvOpLabel, SpvOpPhi, or SpvOpLine. + if (auto error = ValidateAdjacency(*vstate, i)) return error; + } + + if (auto error = ValidateEntryPoints(*vstate)) return error; + // CFG checks are performed after the binary has been parsed + // and the CFGPass has collected information about the control flow + if (auto error = PerformCfgChecks(*vstate)) return error; + if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error; + if (auto error = ValidateDecorations(*vstate)) return error; + if (auto error = ValidateInterfaces(*vstate)) return error; + // TODO(dsinclair): Restructure ValidateBuiltins so we can move into the + // for() above as it loops over all ordered_instructions internally. + if (auto error = ValidateBuiltIns(*vstate)) return error; + // These checks must be performed after individual opcode checks because + // those checks register the limitation checked here. + for (const auto inst : vstate->ordered_instructions()) { + if (auto error = ValidateExecutionLimitations(*vstate, &inst)) return error; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ValidateBinaryAndKeepValidationState( + const spv_const_context context, spv_const_validator_options options, + const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, + std::unique_ptr* vstate) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + vstate->reset( + new ValidationState_t(&hijack_context, options, words, num_words)); + + return ValidateBinaryUsingContextAndValidationState( + hijack_context, words, num_words, pDiagnostic, vstate->get()); +} + +} // namespace val +} // namespace spvtools + +spv_result_t spvValidate(const spv_const_context context, + const spv_const_binary binary, + spv_diagnostic* pDiagnostic) { + return spvValidateBinary(context, binary->code, binary->wordCount, + pDiagnostic); +} + +spv_result_t spvValidateBinary(const spv_const_context context, + const uint32_t* words, const size_t num_words, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + // This interface is used for default command line options. + spv_validator_options default_options = spvValidatorOptionsCreate(); + + // Create the ValidationState using the context and default options. + spvtools::val::ValidationState_t vstate(&hijack_context, default_options, + words, num_words); + + spv_result_t result = + spvtools::val::ValidateBinaryUsingContextAndValidationState( + hijack_context, words, num_words, pDiagnostic, &vstate); + + spvValidatorOptionsDestroy(default_options); + return result; +} + +spv_result_t spvValidateWithOptions(const spv_const_context context, + spv_const_validator_options options, + const spv_const_binary binary, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + // Create the ValidationState using the context. + spvtools::val::ValidationState_t vstate(&hijack_context, options, + binary->code, binary->wordCount); + + return spvtools::val::ValidateBinaryUsingContextAndValidationState( + hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate); +} diff --git a/3rdparty/spirv-tools/source/validate.h b/3rdparty/spirv-tools/source/val/validate.h similarity index 59% rename from 3rdparty/spirv-tools/source/validate.h rename to 3rdparty/spirv-tools/source/val/validate.h index a4f6dde28..4599c4a86 100644 --- a/3rdparty/spirv-tools/source/validate.h +++ b/3rdparty/spirv-tools/source/val/validate.h @@ -12,22 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VALIDATE_H_ -#define LIBSPIRV_VALIDATE_H_ +#ifndef SOURCE_VAL_VALIDATE_H_ +#define SOURCE_VAL_VALIDATE_H_ #include +#include #include #include -#include "instruction.h" -#include "message.h" +#include "source/instruction.h" +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" -namespace libspirv { +namespace spvtools { +namespace val { class ValidationState_t; class BasicBlock; +class Instruction; /// A function that returns a vector of BasicBlocks given a BasicBlock. Used to /// get the successor and predecessor nodes of a CFG block @@ -49,7 +51,7 @@ spv_result_t PerformCfgChecks(ValidationState_t& _); /// @param[in] _ the validation state of the module /// /// @return SPV_SUCCESS if no errors are found. -spv_result_t UpdateIdUse(ValidationState_t& _); +spv_result_t UpdateIdUse(ValidationState_t& _, const Instruction* inst); /// @brief This function checks all ID definitions dominate their use in the /// CFG. @@ -73,7 +75,24 @@ spv_result_t CheckIdDefinitionDominateUse(const ValidationState_t& _); /// @param[in] _ the validation state of the module /// /// @return SPV_SUCCESS if no errors are found. SPV_ERROR_INVALID_DATA otherwise -spv_result_t ValidateAdjacency(ValidationState_t& _); +spv_result_t ValidateAdjacency(ValidationState_t& _, size_t idx); + +/// @brief Validates static uses of input and output variables +/// +/// Checks that any entry point that uses a input or output variable lists that +/// variable in its interface. +/// +/// @param[in] _ the validation state of the module +/// +/// @return SPV_SUCCESS if no errors are found. +spv_result_t ValidateInterfaces(ValidationState_t& _); + +/// @brief Validates memory instructions +/// +/// @param[in] _ the validation state of the module +/// @return SPV_SUCCESS if no errors are found. +spv_result_t ValidateMemoryInstructions(ValidationState_t& _, + const Instruction* inst); /// @brief Updates the immediate dominator for each of the block edges /// @@ -94,25 +113,24 @@ void printDominatorList(BasicBlock& block); /// Performs logical layout validation as described in section 2.4 of the SPIR-V /// spec. -spv_result_t ModuleLayoutPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst); -/// Performs Control Flow Graph validation of a module -spv_result_t CfgPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +/// Performs Control Flow Graph validation and construction. +spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst); + +/// Validates Control Flow Graph instructions. +spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst); /// Performs Id and SSA validation of a module -spv_result_t IdPass(ValidationState_t& _, const spv_parsed_instruction_t* inst); +spv_result_t IdPass(ValidationState_t& _, Instruction* inst); /// Performs validation of the Data Rules subsection of 2.16.1 Universal /// Validation Rules. /// TODO(ehsann): add more comments here as more validation code is added. -spv_result_t DataRulesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst); /// Performs instruction validation. -spv_result_t InstructionPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst); /// Performs decoration validation. spv_result_t ValidateDecorations(ValidationState_t& _); @@ -120,79 +138,72 @@ spv_result_t ValidateDecorations(ValidationState_t& _); /// Performs validation of built-in variables. spv_result_t ValidateBuiltIns(const ValidationState_t& _); -/// Validates that type declarations are unique, unless multiple declarations -/// of the same data type are allowed by the specification. -/// (see section 2.8 Types and Variables) -spv_result_t TypeUniquePass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +/// Validates type instructions. +spv_result_t TypePass(ValidationState_t& _, const Instruction* inst); + +/// Validates constant instructions. +spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of arithmetic instructions. -spv_result_t ArithmeticsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of composite instructions. -spv_result_t CompositesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of conversion instructions. -spv_result_t ConversionPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of derivative instructions. -spv_result_t DerivativesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of logical instructions. -spv_result_t LogicalsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of bitwise instructions. -spv_result_t BitwisePass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of image instructions. -spv_result_t ImagePass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of atomic instructions. -spv_result_t AtomicsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of barrier instructions. -spv_result_t BarriersPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of literal numbers. -spv_result_t LiteralsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t LiteralsPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of ExtInst instructions. -spv_result_t ExtInstPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t ExtInstPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of annotation instructions. +spv_result_t AnnotationPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of non-uniform group instructions. +spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of debug instructions. +spv_result_t DebugPass(ValidationState_t& _, const Instruction* inst); // Validates that capability declarations use operands allowed in the current // context. -spv_result_t CapabilityPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of primitive instructions. -spv_result_t PrimitivesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst); +spv_result_t PrimitivesPass(ValidationState_t& _, const Instruction* inst); -} // namespace libspirv +/// Validates correctness of mode setting instructions. +spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst); -/// @brief Validate the ID usage of the instruction stream +/// Validates correctness of function instructions. +spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst); + +/// Validates execution limitations. /// -/// @param[in] pInsts stream of instructions -/// @param[in] instCount number of instructions -/// @param[in] usedefs use-def info from module parsing -/// @param[in,out] position current position in the stream -/// -/// @return result code -spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, - const uint64_t instCount, - const libspirv::ValidationState_t& state, - spv_position position); +/// Verifies execution models are allowed for all functionality they contain. +spv_result_t ValidateExecutionLimitations(ValidationState_t& _, + const Instruction* inst); /// @brief Validate the ID's within a SPIR-V binary /// @@ -206,9 +217,8 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, const uint64_t count, const uint32_t bound, spv_position position, - const spvtools::MessageConsumer& consumer); + const MessageConsumer& consumer); -namespace spvtools { // Performs validation for the SPIRV-V module binary. // The main difference between this API and spvValidateBinary is that the // "Validation State" is not destroyed upon function return; it lives on and is @@ -216,13 +226,9 @@ namespace spvtools { spv_result_t ValidateBinaryAndKeepValidationState( const spv_const_context context, spv_const_validator_options options, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, - std::unique_ptr* vstate); - -// Performs validation for a single instruction and updates given validation -// state. -spv_result_t ValidateInstructionAndUpdateValidationState( - libspirv::ValidationState_t* vstate, const spv_parsed_instruction_t* inst); + std::unique_ptr* vstate); +} // namespace val } // namespace spvtools -#endif // LIBSPIRV_VALIDATE_H_ +#endif // SOURCE_VAL_VALIDATE_H_ diff --git a/3rdparty/spirv-tools/source/val/validate_adjacency.cpp b/3rdparty/spirv-tools/source/val/validate_adjacency.cpp new file mode 100644 index 000000000..5ef56be99 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_adjacency.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of the intra-block preconditions of SPIR-V +// instructions. + +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateAdjacency(ValidationState_t& _, size_t idx) { + const auto& instructions = _.ordered_instructions(); + const auto& inst = instructions[idx]; + + switch (inst.opcode()) { + case SpvOpPhi: + if (idx > 0) { + switch (instructions[idx - 1].opcode()) { + case SpvOpLabel: + case SpvOpPhi: + case SpvOpLine: + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "OpPhi must appear before all non-OpPhi instructions " + << "(except for OpLine, which can be mixed with OpPhi)."; + } + } + break; + case SpvOpLoopMerge: + if (idx != (instructions.size() - 1)) { + switch (instructions[idx + 1].opcode()) { + case SpvOpBranch: + case SpvOpBranchConditional: + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "OpLoopMerge must immediately precede either an " + << "OpBranch or OpBranchConditional instruction. " + << "OpLoopMerge must be the second-to-last instruction in " + << "its block."; + } + } + break; + case SpvOpSelectionMerge: + if (idx != (instructions.size() - 1)) { + switch (instructions[idx + 1].opcode()) { + case SpvOpBranchConditional: + case SpvOpSwitch: + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "OpSelectionMerge must immediately precede either an " + << "OpBranchConditional or OpSwitch instruction. " + << "OpSelectionMerge must be the second-to-last " + << "instruction in its block."; + } + } + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_annotation.cpp b/3rdparty/spirv-tools/source/val/validate_annotation.cpp new file mode 100644 index 000000000..f1758391d --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_annotation.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateDecorate(ValidationState_t& _, const Instruction* inst) { + const auto decoration = inst->GetOperandAs(1); + if (decoration == SpvDecorationSpecId) { + const auto target_id = inst->GetOperandAs(0); + const auto target = _.FindDef(target_id); + if (!target || !spvOpcodeIsScalarSpecConstant(target->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpDecorate SpecId decoration target '" + << _.getIdName(decoration) + << "' is not a scalar specialization constant."; + } + } + // TODO: Add validations for all decorations. + return SPV_SUCCESS; +} + +spv_result_t ValidateMemberDecorate(ValidationState_t& _, + const Instruction* inst) { + const auto struct_type_id = inst->GetOperandAs(0); + const auto struct_type = _.FindDef(struct_type_id); + if (!struct_type || SpvOpTypeStruct != struct_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpMemberDecorate Structure type '" + << _.getIdName(struct_type_id) << "' is not a struct type."; + } + const auto member = inst->GetOperandAs(1); + const auto member_count = + static_cast(struct_type->words().size() - 2); + if (member_count < member) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Index " << member + << " provided in OpMemberDecorate for struct " + << _.getIdName(struct_type_id) + << " is out of bounds. The structure has " << member_count + << " members. Largest valid index is " << member_count - 1 << "."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateDecorationGroup(ValidationState_t& _, + const Instruction* inst) { + const auto decoration_group_id = inst->GetOperandAs(0); + const auto decoration_group = _.FindDef(decoration_group_id); + for (auto pair : decoration_group->uses()) { + auto use = pair.first; + if (use->opcode() != SpvOpDecorate && use->opcode() != SpvOpGroupDecorate && + use->opcode() != SpvOpGroupMemberDecorate && + use->opcode() != SpvOpName) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Result id of OpDecorationGroup can only " + << "be targeted by OpName, OpGroupDecorate, " + << "OpDecorate, and OpGroupMemberDecorate"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateGroupDecorate(ValidationState_t& _, + const Instruction* inst) { + const auto decoration_group_id = inst->GetOperandAs(0); + auto decoration_group = _.FindDef(decoration_group_id); + if (!decoration_group || SpvOpDecorationGroup != decoration_group->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupDecorate Decoration group '" + << _.getIdName(decoration_group_id) + << "' is not a decoration group."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateGroupMemberDecorate(ValidationState_t& _, + const Instruction* inst) { + const auto decoration_group_id = inst->GetOperandAs(0); + const auto decoration_group = _.FindDef(decoration_group_id); + if (!decoration_group || SpvOpDecorationGroup != decoration_group->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupMemberDecorate Decoration group '" + << _.getIdName(decoration_group_id) + << "' is not a decoration group."; + } + // Grammar checks ensures that the number of arguments to this instruction + // is an odd number: 1 decoration group + (id,literal) pairs. + for (size_t i = 1; i + 1 < inst->operands().size(); i += 2) { + const uint32_t struct_id = inst->GetOperandAs(i); + const uint32_t index = inst->GetOperandAs(i + 1); + auto struct_instr = _.FindDef(struct_id); + if (!struct_instr || SpvOpTypeStruct != struct_instr->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupMemberDecorate Structure type '" + << _.getIdName(struct_id) << "' is not a struct type."; + } + const uint32_t num_struct_members = + static_cast(struct_instr->words().size() - 2); + if (index >= num_struct_members) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Index " << index + << " provided in OpGroupMemberDecorate for struct " + << _.getIdName(struct_id) + << " is out of bounds. The structure has " << num_struct_members + << " members. Largest valid index is " << num_struct_members - 1 + << "."; + } + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t AnnotationPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpDecorate: + if (auto error = ValidateDecorate(_, inst)) return error; + break; + case SpvOpMemberDecorate: + if (auto error = ValidateMemberDecorate(_, inst)) return error; + break; + case SpvOpDecorationGroup: + if (auto error = ValidateDecorationGroup(_, inst)) return error; + break; + case SpvOpGroupDecorate: + if (auto error = ValidateGroupDecorate(_, inst)) return error; + break; + case SpvOpGroupMemberDecorate: + if (auto error = ValidateGroupMemberDecorate(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_arithmetics.cpp b/3rdparty/spirv-tools/source/val/validate_arithmetics.cpp similarity index 73% rename from 3rdparty/spirv-tools/source/validate_arithmetics.cpp rename to 3rdparty/spirv-tools/source/val/validate_arithmetics.cpp index 783dbf4de..2314e7dfc 100644 --- a/3rdparty/spirv-tools/source/validate_arithmetics.cpp +++ b/3rdparty/spirv-tools/source/val/validate_arithmetics.cpp @@ -14,41 +14,22 @@ // Performs validation of arithmetic instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include -namespace libspirv { +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" -namespace { - -// Returns operand word for given instruction and operand index. -// The operand is expected to only have one word. -inline uint32_t GetOperandWord(const spv_parsed_instruction_t* inst, - size_t operand_index) { - assert(operand_index < inst->num_operands); - const spv_parsed_operand_t& operand = inst->operands[operand_index]; - assert(operand.num_words == 1); - return inst->words[operand.offset]; -} - -// Returns the type id of instruction operand at |operand_index|. -// The operand is expected to be an id. -inline uint32_t GetOperandTypeId(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - size_t operand_index) { - return _.GetTypeId(GetOperandWord(inst, operand_index)); -} -} // namespace +namespace spvtools { +namespace val { // Validates correctness of arithmetic instructions. -spv_result_t ArithmeticsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpFAdd: @@ -60,14 +41,14 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpFNegate: { if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected floating scalar or vector type as Result Type: " << spvOpcodeString(opcode); - for (size_t operand_index = 2; operand_index < inst->num_operands; + for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - if (GetOperandTypeId(_, inst, operand_index) != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + if (_.GetOperandTypeId(inst, operand_index) != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected arithmetic operands to be of Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; @@ -79,14 +60,14 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpUMod: { if (!_.IsUnsignedIntScalarType(result_type) && !_.IsUnsignedIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); - for (size_t operand_index = 2; operand_index < inst->num_operands; + for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - if (GetOperandTypeId(_, inst, operand_index) != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + if (_.GetOperandTypeId(inst, operand_index) != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected arithmetic operands to be of Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; @@ -102,31 +83,31 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpSRem: case SpvOpSNegate: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t dimension = _.GetDimension(result_type); const uint32_t bit_width = _.GetBitWidth(result_type); - for (size_t operand_index = 2; operand_index < inst->num_operands; + for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - const uint32_t type_id = GetOperandTypeId(_, inst, operand_index); + const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!type_id || (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as operand: " << spvOpcodeString(opcode) << " operand index " << operand_index; if (_.GetDimension(type_id) != dimension) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected arithmetic operands to have the same dimension " << "as Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; if (_.GetBitWidth(type_id) != bit_width) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected arithmetic operands to have the same bit width " << "as Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; @@ -136,25 +117,25 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpDot: { if (!_.IsFloatScalarType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar type as Result Type: " << spvOpcodeString(opcode); uint32_t first_vector_num_components = 0; - for (size_t operand_index = 2; operand_index < inst->num_operands; + for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - const uint32_t type_id = GetOperandTypeId(_, inst, operand_index); + const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!type_id || !_.IsFloatVectorType(type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector as operand: " << spvOpcodeString(opcode) << " operand index " << operand_index; const uint32_t component_type = _.GetComponentType(type_id); if (component_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected component type to be equal to Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; @@ -163,7 +144,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, if (operand_index == 2) { first_vector_num_components = num_components; } else if (num_components != first_vector_num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operands to have the same number of componenets: " << spvOpcodeString(opcode); } @@ -173,21 +154,21 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpVectorTimesScalar: { if (!_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t vector_type_id = GetOperandTypeId(_, inst, 2); + const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2); if (result_type != vector_type_id) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected vector operand type to be equal to Result Type: " << spvOpcodeString(opcode); const uint32_t component_type = _.GetComponentType(vector_type_id); - const uint32_t scalar_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3); if (component_type != scalar_type_id) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected scalar operand type to be equal to the component " << "type of the vector operand: " << spvOpcodeString(opcode); @@ -196,21 +177,21 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpMatrixTimesScalar: { if (!_.IsFloatMatrixType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as Result Type: " << spvOpcodeString(opcode); - const uint32_t matrix_type_id = GetOperandTypeId(_, inst, 2); + const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2); if (result_type != matrix_type_id) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected matrix operand type to be equal to Result Type: " << spvOpcodeString(opcode); const uint32_t component_type = _.GetComponentType(matrix_type_id); - const uint32_t scalar_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3); if (component_type != scalar_type_id) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected scalar operand type to be equal to the component " << "type of the matrix operand: " << spvOpcodeString(opcode); @@ -218,23 +199,23 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, } case SpvOpVectorTimesMatrix: { - const uint32_t vector_type_id = GetOperandTypeId(_, inst, 2); - const uint32_t matrix_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2); + const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3); if (!_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t res_component_type = _.GetComponentType(result_type); if (!vector_type_id || !_.IsFloatVectorType(vector_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as left operand: " << spvOpcodeString(opcode); if (res_component_type != _.GetComponentType(vector_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected component types of Result Type and vector to be " << "equal: " << spvOpcodeString(opcode); @@ -245,22 +226,22 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows, &matrix_num_cols, &matrix_col_type, &matrix_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as right operand: " << spvOpcodeString(opcode); if (res_component_type != matrix_component_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected component types of Result Type and matrix to be " << "equal: " << spvOpcodeString(opcode); if (matrix_num_cols != _.GetDimension(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected number of columns of the matrix to be equal to " << "Result Type vector size: " << spvOpcodeString(opcode); if (matrix_num_rows != _.GetDimension(vector_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected number of rows of the matrix to be equal to the " << "vector operand size: " << spvOpcodeString(opcode); @@ -268,11 +249,11 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, } case SpvOpMatrixTimesVector: { - const uint32_t matrix_type_id = GetOperandTypeId(_, inst, 2); - const uint32_t vector_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2); + const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3); if (!_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as Result Type: " << spvOpcodeString(opcode); @@ -283,28 +264,28 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows, &matrix_num_cols, &matrix_col_type, &matrix_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as left operand: " << spvOpcodeString(opcode); if (result_type != matrix_col_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected column type of the matrix to be equal to Result " "Type: " << spvOpcodeString(opcode); if (!vector_type_id || !_.IsFloatVectorType(vector_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as right operand: " << spvOpcodeString(opcode); if (matrix_component_type != _.GetComponentType(vector_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected component types of the operands to be equal: " << spvOpcodeString(opcode); if (matrix_num_cols != _.GetDimension(vector_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected number of columns of the matrix to be equal to the " << "vector size: " << spvOpcodeString(opcode); @@ -312,8 +293,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, } case SpvOpMatrixTimesMatrix: { - const uint32_t left_type_id = GetOperandTypeId(_, inst, 2); - const uint32_t right_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t left_type_id = _.GetOperandTypeId(inst, 2); + const uint32_t right_type_id = _.GetOperandTypeId(inst, 3); uint32_t res_num_rows = 0; uint32_t res_num_cols = 0; @@ -321,7 +302,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, uint32_t res_component_type = 0; if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols, &res_col_type, &res_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as Result Type: " << spvOpcodeString(opcode); @@ -331,7 +312,7 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, uint32_t left_component_type = 0; if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols, &left_col_type, &left_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as left operand: " << spvOpcodeString(opcode); @@ -341,34 +322,34 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, uint32_t right_component_type = 0; if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols, &right_col_type, &right_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as right operand: " << spvOpcodeString(opcode); if (!_.IsFloatScalarType(res_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as Result Type: " << spvOpcodeString(opcode); if (res_col_type != left_col_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected column types of Result Type and left matrix to be " << "equal: " << spvOpcodeString(opcode); if (res_component_type != right_component_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected component types of Result Type and right matrix to " "be " << "equal: " << spvOpcodeString(opcode); if (res_num_cols != right_num_cols) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected number of columns of Result Type and right matrix " "to " << "be equal: " << spvOpcodeString(opcode); if (left_num_cols != right_num_rows) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected number of columns of left matrix and number of " "rows " << "of right matrix to be equal: " << spvOpcodeString(opcode); @@ -378,8 +359,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, } case SpvOpOuterProduct: { - const uint32_t left_type_id = GetOperandTypeId(_, inst, 2); - const uint32_t right_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t left_type_id = _.GetOperandTypeId(inst, 2); + const uint32_t right_type_id = _.GetOperandTypeId(inst, 3); uint32_t res_num_rows = 0; uint32_t res_num_cols = 0; @@ -387,27 +368,27 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, uint32_t res_component_type = 0; if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols, &res_col_type, &res_component_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float matrix type as Result Type: " << spvOpcodeString(opcode); if (left_type_id != res_col_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected column type of Result Type to be equal to the type " << "of the left operand: " << spvOpcodeString(opcode); if (!right_type_id || !_.IsFloatVectorType(right_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as right operand: " << spvOpcodeString(opcode); if (res_component_type != _.GetComponentType(right_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected component types of the operands to be equal: " << spvOpcodeString(opcode); if (res_num_cols != _.GetDimension(right_type_id)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected number of columns of the matrix to be equal to the " << "vector size of the right operand: " << spvOpcodeString(opcode); @@ -421,40 +402,40 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, case SpvOpSMulExtended: { std::vector result_types; if (!_.GetStructMemberTypes(result_type, &result_types)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected a struct as Result Type: " << spvOpcodeString(opcode); if (result_types.size() != 2) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type struct to have two members: " << spvOpcodeString(opcode); if (opcode == SpvOpSMulExtended) { if (!_.IsIntScalarType(result_types[0]) && !_.IsIntVectorType(result_types[0])) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type struct member types to be integer " "scalar " << "or vector: " << spvOpcodeString(opcode); } else { if (!_.IsUnsignedIntScalarType(result_types[0]) && !_.IsUnsignedIntVectorType(result_types[0])) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type struct member types to be unsigned " << "integer scalar or vector: " << spvOpcodeString(opcode); } if (result_types[0] != result_types[1]) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type struct member types to be identical: " << spvOpcodeString(opcode); - const uint32_t left_type_id = GetOperandTypeId(_, inst, 2); - const uint32_t right_type_id = GetOperandTypeId(_, inst, 3); + const uint32_t left_type_id = _.GetOperandTypeId(inst, 2); + const uint32_t right_type_id = _.GetOperandTypeId(inst, 3); if (left_type_id != result_types[0] || right_type_id != result_types[0]) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected both operands to be of Result Type member type: " << spvOpcodeString(opcode); @@ -468,4 +449,5 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_atomics.cpp b/3rdparty/spirv-tools/source/val/validate_atomics.cpp similarity index 82% rename from 3rdparty/spirv-tools/source/validate_atomics.cpp rename to 3rdparty/spirv-tools/source/val/validate_atomics.cpp index 8c0c535f9..becb87200 100644 --- a/3rdparty/spirv-tools/source/validate_atomics.cpp +++ b/3rdparty/spirv-tools/source/val/validate_atomics.cpp @@ -14,28 +14,28 @@ // Validates correctness of atomic SPIR-V instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "spirv_target_env.h" -#include "util/bitutils.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" -namespace libspirv { +namespace spvtools { +namespace val { // Validates Memory Scope operand. -spv_result_t ValidateMemoryScope(ValidationState_t& _, - const spv_parsed_instruction_t* inst, +spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst, uint32_t id) { - const SpvOp opcode = static_cast(inst->opcode); + const SpvOp opcode = inst->opcode(); bool is_int32 = false, is_const_int32 = false; uint32_t value = 0; std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id); if (!is_int32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Scope to be 32-bit int"; } @@ -48,7 +48,7 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, if (spvIsVulkanEnv(_.context()->target_env)) { if (value != SpvScopeDevice && value != SpvScopeWorkgroup && value != SpvScopeInvocation) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": in Vulkan environment memory scope is limited to Device, " "Workgroup and Invocation"; @@ -63,18 +63,17 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, // Validates a Memory Semantics operand. spv_result_t ValidateMemorySemantics(ValidationState_t& _, - const spv_parsed_instruction_t* inst, + const Instruction* inst, uint32_t operand_index) { - const SpvOp opcode = static_cast(inst->opcode); + const SpvOp opcode = inst->opcode(); bool is_int32 = false, is_const_int32 = false; uint32_t flags = 0; - const uint32_t memory_semantics_id = - inst->words[inst->operands[operand_index].offset]; + auto memory_semantics_id = inst->GetOperandAs(operand_index); std::tie(is_int32, is_const_int32, flags) = _.EvalInt32IfConst(memory_semantics_id); if (!is_int32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Memory Semantics to be 32-bit int"; } @@ -83,12 +82,12 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, return SPV_SUCCESS; } - if (spvutils::CountSetBits( + if (spvtools::utils::CountSetBits( flags & (SpvMemorySemanticsAcquireMask | SpvMemorySemanticsReleaseMask | SpvMemorySemanticsAcquireReleaseMask | SpvMemorySemanticsSequentiallyConsistentMask)) > 1) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": no more than one of the following Memory Semantics bits can " "be set at the same time: Acquire, Release, AcquireRelease or " @@ -97,14 +96,14 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, if (flags & SpvMemorySemanticsUniformMemoryMask && !_.HasCapability(SpvCapabilityShader)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": Memory Semantics UniformMemory requires capability Shader"; } if (flags & SpvMemorySemanticsAtomicCounterMemoryMask && !_.HasCapability(SpvCapabilityAtomicStorage)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": Memory Semantics UniformMemory requires capability " "AtomicStorage"; @@ -113,7 +112,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, if (opcode == SpvOpAtomicFlagClear && (flags & SpvMemorySemanticsAcquireMask || flags & SpvMemorySemanticsAcquireReleaseMask)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Memory Semantics Acquire and AcquireRelease cannot be used with " << spvOpcodeString(opcode); } @@ -121,7 +120,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, if (opcode == SpvOpAtomicCompareExchange && operand_index == 5 && (flags & SpvMemorySemanticsReleaseMask || flags & SpvMemorySemanticsAcquireReleaseMask)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": Memory Semantics Release and AcquireRelease cannot be used " "for operand Unequal"; @@ -132,7 +131,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, (flags & SpvMemorySemanticsReleaseMask || flags & SpvMemorySemanticsAcquireReleaseMask || flags & SpvMemorySemanticsSequentiallyConsistentMask)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Vulkan spec disallows OpAtomicLoad with Memory Semantics " "Release, AcquireRelease and SequentiallyConsistent"; } @@ -141,7 +140,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, (flags & SpvMemorySemanticsAcquireMask || flags & SpvMemorySemanticsAcquireReleaseMask || flags & SpvMemorySemanticsSequentiallyConsistentMask)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Vulkan spec disallows OpAtomicStore with Memory Semantics " "Acquire, AcquireRelease and SequentiallyConsistent"; } @@ -153,10 +152,9 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, } // Validates correctness of atomic instructions. -spv_result_t AtomicsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpAtomicLoad: @@ -182,13 +180,13 @@ spv_result_t AtomicsPass(ValidationState_t& _, opcode == SpvOpAtomicCompareExchange)) { if (!_.IsFloatScalarType(result_type) && !_.IsIntScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Result Type to be int or float scalar type"; } } else if (opcode == SpvOpAtomicFlagTestAndSet) { if (!_.IsBoolScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Result Type to be bool scalar type"; } @@ -196,13 +194,13 @@ spv_result_t AtomicsPass(ValidationState_t& _, assert(result_type == 0); } else { if (!_.IsIntScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Result Type to be int scalar type"; } if (spvIsVulkanEnv(_.context()->target_env) && _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": according to the Vulkan spec atomic Result Type needs " "to be a 32-bit int scalar type"; @@ -216,7 +214,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, uint32_t data_type = 0; uint32_t storage_class = 0; if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Pointer to be of type OpTypePointer"; } @@ -231,7 +229,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, case SpvStorageClassStorageBuffer: break; default: - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Pointer Storage Class to be Uniform, " "Workgroup, CrossWorkgroup, Generic, AtomicCounter, Image " @@ -241,28 +239,27 @@ spv_result_t AtomicsPass(ValidationState_t& _, if (opcode == SpvOpAtomicFlagTestAndSet || opcode == SpvOpAtomicFlagClear) { if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Pointer to point to a value of 32-bit int type"; } } else if (opcode == SpvOpAtomicStore) { if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Pointer to be a pointer to int or float " << "scalar type"; } } else { if (data_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Pointer to point to a value of type Result " "Type"; } } - const uint32_t memory_scope = - inst->words[inst->operands[operand_index++].offset]; + auto memory_scope = inst->GetOperandAs(operand_index++); if (auto error = ValidateMemoryScope(_, inst, memory_scope)) { return error; } @@ -279,7 +276,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, if (opcode == SpvOpAtomicStore) { const uint32_t value_type = _.GetOperandTypeId(inst, 3); if (value_type != data_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Value type and the type pointed to by Pointer " "to" @@ -291,7 +288,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, opcode != SpvOpAtomicFlagClear) { const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++); if (value_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Value to be of type Result Type"; } @@ -302,7 +299,7 @@ spv_result_t AtomicsPass(ValidationState_t& _, const uint32_t comparator_type = _.GetOperandTypeId(inst, operand_index++); if (comparator_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Comparator to be of type Result Type"; } @@ -318,4 +315,5 @@ spv_result_t AtomicsPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_barriers.cpp b/3rdparty/spirv-tools/source/val/validate_barriers.cpp similarity index 60% rename from 3rdparty/spirv-tools/source/validate_barriers.cpp rename to 3rdparty/spirv-tools/source/val/validate_barriers.cpp index 9c77b8fc8..0771f2d28 100644 --- a/3rdparty/spirv-tools/source/validate_barriers.cpp +++ b/3rdparty/spirv-tools/source/val/validate_barriers.cpp @@ -14,31 +14,32 @@ // Validates correctness of barrier SPIR-V instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "spirv_constant.h" -#include "spirv_target_env.h" -#include "util/bitutils.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include -namespace libspirv { +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { // Validates Execution Scope operand. spv_result_t ValidateExecutionScope(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - uint32_t id) { - const SpvOp opcode = static_cast(inst->opcode); + const Instruction* inst, uint32_t id) { + const SpvOp opcode = inst->opcode(); bool is_int32 = false, is_const_int32 = false; uint32_t value = 0; std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id); if (!is_int32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Execution Scope to be a 32-bit int"; } @@ -49,11 +50,32 @@ spv_result_t ValidateExecutionScope(ValidationState_t& _, if (spvIsVulkanEnv(_.context()->target_env)) { if (value != SpvScopeWorkgroup && value != SpvScopeSubgroup) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": in Vulkan environment Execution Scope is limited to " "Workgroup and Subgroup"; } + + if (_.context()->target_env != SPV_ENV_VULKAN_1_0 && + value != SpvScopeSubgroup) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation([](SpvExecutionModel model, + std::string* message) { + if (model == SpvExecutionModelFragment || + model == SpvExecutionModelVertex || + model == SpvExecutionModelGeometry || + model == SpvExecutionModelTessellationEvaluation) { + if (message) { + *message = + "in Vulkan evironment, OpControlBarrier execution scope " + "must be Subgroup for Fragment, Vertex, Geometry and " + "TessellationEvaluation execution models"; + } + return false; + } + return true; + }); + } } // TODO(atgoo@github.com) Add checks for OpenCL and OpenGL environments. @@ -62,16 +84,15 @@ spv_result_t ValidateExecutionScope(ValidationState_t& _, } // Validates Memory Scope operand. -spv_result_t ValidateMemoryScope(ValidationState_t& _, - const spv_parsed_instruction_t* inst, +spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst, uint32_t id) { - const SpvOp opcode = static_cast(inst->opcode); + const SpvOp opcode = inst->opcode(); bool is_int32 = false, is_const_int32 = false; uint32_t value = 0; std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id); if (!is_int32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Memory Scope to be a 32-bit int"; } @@ -81,11 +102,18 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, } if (spvIsVulkanEnv(_.context()->target_env)) { - if (value != SpvScopeDevice && value != SpvScopeWorkgroup && - value != SpvScopeInvocation) { - return _.diag(SPV_ERROR_INVALID_DATA) + if (value == SpvScopeCrossDevice) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) - << ": in Vulkan environment Memory Scope is limited to Device, " + << ": in Vulkan environment, Memory Scope cannot be CrossDevice"; + } + if (_.context()->target_env == SPV_ENV_VULKAN_1_0 && + value != SpvScopeDevice && value != SpvScopeWorkgroup && + value != SpvScopeInvocation) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan 1.0 environment Memory Scope is limited to " + "Device, " "Workgroup and Invocation"; } } @@ -97,15 +125,14 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, // Validates Memory Semantics operand. spv_result_t ValidateMemorySemantics(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - uint32_t id) { - const SpvOp opcode = static_cast(inst->opcode); + const Instruction* inst, uint32_t id) { + const SpvOp opcode = inst->opcode(); bool is_int32 = false, is_const_int32 = false; uint32_t value = 0; std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id); if (!is_int32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Memory Semantics to be a 32-bit int"; } @@ -114,13 +141,13 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, return SPV_SUCCESS; } - const size_t num_memory_order_set_bits = spvutils::CountSetBits( + const size_t num_memory_order_set_bits = spvtools::utils::CountSetBits( value & (SpvMemorySemanticsAcquireMask | SpvMemorySemanticsReleaseMask | SpvMemorySemanticsAcquireReleaseMask | SpvMemorySemanticsSequentiallyConsistentMask)); if (num_memory_order_set_bits > 1) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": Memory Semantics can have at most one of the following bits " "set: Acquire, Release, AcquireRelease or SequentiallyConsistent"; @@ -133,7 +160,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, SpvMemorySemanticsImageMemoryMask); if (opcode == SpvOpMemoryBarrier && !num_memory_order_set_bits) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": Vulkan specification requires Memory Semantics to have one " "of the following bits set: Acquire, Release, AcquireRelease " @@ -141,7 +168,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, } if (opcode == SpvOpMemoryBarrier && !includes_storage_class) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Memory Semantics to include a Vulkan-supported " "storage class"; @@ -150,7 +177,7 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, #if 0 // TODO(atgoo@github.com): this check fails Vulkan CTS, reenable once fixed. if (opcode == SpvOpControlBarrier && value && !includes_storage_class) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Memory Semantics to include a Vulkan-supported " "storage class if Memory Semantics is not None"; @@ -163,38 +190,38 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _, return SPV_SUCCESS; } -} // anonymous namespace +} // namespace // Validates correctness of barrier instructions. -spv_result_t BarriersPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpControlBarrier: { if (spvVersionForTargetEnv(_.context()->target_env) < SPV_SPIRV_VERSION_WORD(1, 3)) { - _.current_function().RegisterExecutionModelLimitation( - [](SpvExecutionModel model, std::string* message) { - if (model != SpvExecutionModelTessellationControl && - model != SpvExecutionModelGLCompute && - model != SpvExecutionModelKernel) { - if (message) { - *message = - "OpControlBarrier requires one of the following " - "Execution " - "Models: TessellationControl, GLCompute or Kernel"; - } - return false; - } - return true; - }); + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + [](SpvExecutionModel model, std::string* message) { + if (model != SpvExecutionModelTessellationControl && + model != SpvExecutionModelGLCompute && + model != SpvExecutionModelKernel) { + if (message) { + *message = + "OpControlBarrier requires one of the following " + "Execution " + "Models: TessellationControl, GLCompute or Kernel"; + } + return false; + } + return true; + }); } - const uint32_t execution_scope = inst->words[1]; - const uint32_t memory_scope = inst->words[2]; - const uint32_t memory_semantics = inst->words[3]; + const uint32_t execution_scope = inst->word(1); + const uint32_t memory_scope = inst->word(2); + const uint32_t memory_semantics = inst->word(3); if (auto error = ValidateExecutionScope(_, inst, execution_scope)) { return error; @@ -211,8 +238,8 @@ spv_result_t BarriersPass(ValidationState_t& _, } case SpvOpMemoryBarrier: { - const uint32_t memory_scope = inst->words[1]; - const uint32_t memory_semantics = inst->words[2]; + const uint32_t memory_scope = inst->word(1); + const uint32_t memory_semantics = inst->word(2); if (auto error = ValidateMemoryScope(_, inst, memory_scope)) { return error; @@ -226,7 +253,7 @@ spv_result_t BarriersPass(ValidationState_t& _, case SpvOpNamedBarrierInitialize: { if (_.GetIdOpcode(result_type) != SpvOpTypeNamedBarrier) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Result Type to be OpTypeNamedBarrier"; } @@ -234,7 +261,7 @@ spv_result_t BarriersPass(ValidationState_t& _, const uint32_t subgroup_count_type = _.GetOperandTypeId(inst, 2); if (!_.IsIntScalarType(subgroup_count_type) || _.GetBitWidth(subgroup_count_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Subgroup Count to be a 32-bit int"; } @@ -244,13 +271,13 @@ spv_result_t BarriersPass(ValidationState_t& _, case SpvOpMemoryNamedBarrier: { const uint32_t named_barrier_type = _.GetOperandTypeId(inst, 0); if (_.GetIdOpcode(named_barrier_type) != SpvOpTypeNamedBarrier) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Named Barrier to be of type OpTypeNamedBarrier"; } - const uint32_t memory_scope = inst->words[2]; - const uint32_t memory_semantics = inst->words[3]; + const uint32_t memory_scope = inst->word(2); + const uint32_t memory_semantics = inst->word(3); if (auto error = ValidateMemoryScope(_, inst, memory_scope)) { return error; @@ -269,4 +296,5 @@ spv_result_t BarriersPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_bitwise.cpp b/3rdparty/spirv-tools/source/val/validate_bitwise.cpp similarity index 66% rename from 3rdparty/spirv-tools/source/validate_bitwise.cpp rename to 3rdparty/spirv-tools/source/val/validate_bitwise.cpp index 94978d93f..d46b3fcab 100644 --- a/3rdparty/spirv-tools/source/validate_bitwise.cpp +++ b/3rdparty/spirv-tools/source/val/validate_bitwise.cpp @@ -14,79 +14,58 @@ // Validates correctness of bitwise instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" -namespace libspirv { - -namespace { - -// Returns operand word for given instruction and operand index. -// The operand is expected to only have one word. -inline uint32_t GetOperandWord(const spv_parsed_instruction_t* inst, - size_t operand_index) { - assert(operand_index < inst->num_operands); - const spv_parsed_operand_t& operand = inst->operands[operand_index]; - assert(operand.num_words == 1); - return inst->words[operand.offset]; -} - -// Returns the type id of instruction operand at |operand_index|. -// The operand is expected to be an id. -inline uint32_t GetOperandTypeId(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - size_t operand_index) { - return _.GetTypeId(GetOperandWord(inst, operand_index)); -} -} // namespace +namespace spvtools { +namespace val { // Validates correctness of bitwise instructions. -spv_result_t BitwisePass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpShiftRightLogical: case SpvOpShiftRightArithmetic: case SpvOpShiftLeftLogical: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t result_dimension = _.GetDimension(result_type); - const uint32_t base_type = GetOperandTypeId(_, inst, 2); - const uint32_t shift_type = GetOperandTypeId(_, inst, 3); + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + const uint32_t shift_type = _.GetOperandTypeId(inst, 3); if (!base_type || (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base to be int scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(base_type) != result_dimension) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base to have the same dimension " << "as Result Type: " << spvOpcodeString(opcode); if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base to have the same bit width " << "as Result Type: " << spvOpcodeString(opcode); if (!shift_type || (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Shift to be int scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(shift_type) != result_dimension) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Shift to have the same dimension " << "as Result Type: " << spvOpcodeString(opcode); break; @@ -97,31 +76,31 @@ spv_result_t BitwisePass(ValidationState_t& _, case SpvOpBitwiseAnd: case SpvOpNot: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t result_dimension = _.GetDimension(result_type); const uint32_t result_bit_width = _.GetBitWidth(result_type); - for (size_t operand_index = 2; operand_index < inst->num_operands; + for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - const uint32_t type_id = GetOperandTypeId(_, inst, operand_index); + const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!type_id || (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector as operand: " << spvOpcodeString(opcode) << " operand index " << operand_index; if (_.GetDimension(type_id) != result_dimension) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operands to have the same dimension " << "as Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; if (_.GetBitWidth(type_id) != result_bit_width) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operands to have the same bit width " << "as Result Type: " << spvOpcodeString(opcode) << " operand index " << operand_index; @@ -131,32 +110,32 @@ spv_result_t BitwisePass(ValidationState_t& _, case SpvOpBitFieldInsert: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t base_type = GetOperandTypeId(_, inst, 2); - const uint32_t insert_type = GetOperandTypeId(_, inst, 3); - const uint32_t offset_type = GetOperandTypeId(_, inst, 4); - const uint32_t count_type = GetOperandTypeId(_, inst, 5); + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + const uint32_t insert_type = _.GetOperandTypeId(inst, 3); + const uint32_t offset_type = _.GetOperandTypeId(inst, 4); + const uint32_t count_type = _.GetOperandTypeId(inst, 5); if (base_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base Type to be equal to Result Type: " << spvOpcodeString(opcode); if (insert_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Insert Type to be equal to Result Type: " << spvOpcodeString(opcode); if (!offset_type || !_.IsIntScalarType(offset_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Offset Type to be int scalar: " << spvOpcodeString(opcode); if (!count_type || !_.IsIntScalarType(count_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Count Type to be int scalar: " << spvOpcodeString(opcode); break; @@ -165,26 +144,26 @@ spv_result_t BitwisePass(ValidationState_t& _, case SpvOpBitFieldSExtract: case SpvOpBitFieldUExtract: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t base_type = GetOperandTypeId(_, inst, 2); - const uint32_t offset_type = GetOperandTypeId(_, inst, 3); - const uint32_t count_type = GetOperandTypeId(_, inst, 4); + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + const uint32_t offset_type = _.GetOperandTypeId(inst, 3); + const uint32_t count_type = _.GetOperandTypeId(inst, 4); if (base_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base Type to be equal to Result Type: " << spvOpcodeString(opcode); if (!offset_type || !_.IsIntScalarType(offset_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Offset Type to be int scalar: " << spvOpcodeString(opcode); if (!count_type || !_.IsIntScalarType(count_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Count Type to be int scalar: " << spvOpcodeString(opcode); break; @@ -192,14 +171,14 @@ spv_result_t BitwisePass(ValidationState_t& _, case SpvOpBitReverse: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t base_type = GetOperandTypeId(_, inst, 2); + const uint32_t base_type = _.GetOperandTypeId(inst, 2); if (base_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base Type to be equal to Result Type: " << spvOpcodeString(opcode); break; @@ -207,14 +186,14 @@ spv_result_t BitwisePass(ValidationState_t& _, case SpvOpBitCount: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t base_type = GetOperandTypeId(_, inst, 2); + const uint32_t base_type = _.GetOperandTypeId(inst, 2); if (!base_type || (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base Type to be int scalar or vector: " << spvOpcodeString(opcode); @@ -222,7 +201,7 @@ spv_result_t BitwisePass(ValidationState_t& _, const uint32_t result_dimension = _.GetDimension(result_type); if (base_dimension != result_dimension) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base dimension to be equal to Result Type " "dimension: " << spvOpcodeString(opcode); @@ -236,4 +215,5 @@ spv_result_t BitwisePass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_builtins.cpp b/3rdparty/spirv-tools/source/val/validate_builtins.cpp similarity index 81% rename from 3rdparty/spirv-tools/source/validate_builtins.cpp rename to 3rdparty/spirv-tools/source/val/validate_builtins.cpp index c494a888e..c791e428f 100644 --- a/3rdparty/spirv-tools/source/validate_builtins.cpp +++ b/3rdparty/spirv-tools/source/val/validate_builtins.cpp @@ -14,7 +14,7 @@ // Validates correctness of built-in variables. -#include "validate.h" +#include "source/val/validate.h" #include #include @@ -22,18 +22,19 @@ #include #include #include +#include #include #include -#include "diagnostic.h" -#include "opcode.h" -#include "spirv_target_env.h" -#include "util/bitutils.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { // Returns a short textual description of the id defined by the given @@ -73,7 +74,7 @@ spv_result_t GetUnderlyingType(const ValidationState_t& _, uint32_t storage_class = 0; if (!_.GetPointerTypeInfo(inst.type_id(), underlying_type, &storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << GetIdDesc(inst) << " is decorated with BuiltIn. BuiltIn decoration should only be " "applied to struct types, variables and constants."; @@ -319,15 +320,40 @@ class BuiltInsValidator { spv_result_t ValidateF32( const Decoration& decoration, const Instruction& inst, const std::function& diag); + spv_result_t ValidateOptionalArrayedF32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag); + spv_result_t ValidateF32Helper( + const Decoration& decoration, const Instruction& inst, + const std::function& diag, + uint32_t underlying_type); spv_result_t ValidateF32Vec( const Decoration& decoration, const Instruction& inst, uint32_t num_components, const std::function& diag); + spv_result_t ValidateOptionalArrayedF32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateF32VecHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type); // If |num_components| is zero, the number of components is not checked. spv_result_t ValidateF32Arr( const Decoration& decoration, const Instruction& inst, uint32_t num_components, const std::function& diag); + spv_result_t ValidateOptionalArrayedF32Arr( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateF32ArrHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type); // Generates strings like "Member #0 of struct ID <2>". std::string GetDefinitionDesc(const Decoration& decoration, @@ -349,9 +375,6 @@ class BuiltInsValidator { // instruction. void Update(const Instruction& inst); - // Traverses call tree and computes function_to_entry_points_. - void ComputeFunctionToEntryPointMapping(); - const ValidationState_t& _; // Mapping id -> list of rules which validate instruction referencing the @@ -370,10 +393,6 @@ class BuiltInsValidator { const std::vector no_entry_points; const std::vector* entry_points_ = &no_entry_points; - // Mapping function -> array of entry points inside this - // module which can (indirectly) call the function. - std::unordered_map> function_to_entry_points_; - // Execution models with which the current function can be called. std::set execution_models_; }; @@ -385,17 +404,12 @@ void BuiltInsValidator::Update(const Instruction& inst) { assert(function_id_ == 0); function_id_ = inst.id(); execution_models_.clear(); - const auto it = function_to_entry_points_.find(function_id_); - if (it == function_to_entry_points_.end()) { - entry_points_ = &no_entry_points; - } else { - entry_points_ = &it->second; - // Collect execution models from all entry points from which the current - // function can be called. - for (const uint32_t entry_point : *entry_points_) { - if (const auto* models = _.GetExecutionModels(entry_point)) { - execution_models_.insert(models->begin(), models->end()); - } + entry_points_ = &_.FunctionEntryPoints(function_id_); + // Collect execution models from all entry points from which the current + // function can be called. + for (const uint32_t entry_point : *entry_points_) { + if (const auto* models = _.GetExecutionModels(entry_point)) { + execution_models_.insert(models->begin(), models->end()); } } } @@ -409,28 +423,6 @@ void BuiltInsValidator::Update(const Instruction& inst) { } } -void BuiltInsValidator::ComputeFunctionToEntryPointMapping() { - // TODO: Move this into validation_state.cpp. - for (const uint32_t entry_point : _.entry_points()) { - std::stack call_stack; - std::set visited; - call_stack.push(entry_point); - while (!call_stack.empty()) { - const uint32_t called_func_id = call_stack.top(); - call_stack.pop(); - if (!visited.insert(called_func_id).second) continue; - - function_to_entry_points_[called_func_id].push_back(entry_point); - - const Function* called_func = _.function(called_func_id); - assert(called_func); - for (const uint32_t new_call : called_func->function_call_targets()) { - call_stack.push(new_call); - } - } - } -} - std::string BuiltInsValidator::GetDefinitionDesc( const Decoration& decoration, const Instruction& inst) const { std::ostringstream ss; @@ -520,6 +512,23 @@ spv_result_t BuiltInsValidator::ValidateI32( return SPV_SUCCESS; } +spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + // Strip the array, if present. + if (_.GetIdOpcode(underlying_type) == SpvOpTypeArray) { + underlying_type = _.FindDef(underlying_type)->word(2u); + } + + return ValidateF32Helper(decoration, inst, diag, underlying_type); +} + spv_result_t BuiltInsValidator::ValidateF32( const Decoration& decoration, const Instruction& inst, const std::function& diag) { @@ -529,6 +538,13 @@ spv_result_t BuiltInsValidator::ValidateF32( return error; } + return ValidateF32Helper(decoration, inst, diag, underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32Helper( + const Decoration& decoration, const Instruction& inst, + const std::function& diag, + uint32_t underlying_type) { if (!_.IsFloatScalarType(underlying_type)) { return diag(GetDefinitionDesc(decoration, inst) + " is not a float scalar."); @@ -578,6 +594,25 @@ spv_result_t BuiltInsValidator::ValidateI32Vec( return SPV_SUCCESS; } +spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + // Strip the array, if present. + if (_.GetIdOpcode(underlying_type) == SpvOpTypeArray) { + underlying_type = _.FindDef(underlying_type)->word(2u); + } + + return ValidateF32VecHelper(decoration, inst, num_components, diag, + underlying_type); +} + spv_result_t BuiltInsValidator::ValidateF32Vec( const Decoration& decoration, const Instruction& inst, uint32_t num_components, @@ -588,6 +623,15 @@ spv_result_t BuiltInsValidator::ValidateF32Vec( return error; } + return ValidateF32VecHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32VecHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type) { if (!_.IsFloatVectorType(underlying_type)) { return diag(GetDefinitionDesc(decoration, inst) + " is not a float vector."); @@ -653,6 +697,37 @@ spv_result_t BuiltInsValidator::ValidateF32Arr( return error; } + return ValidateF32ArrHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32Arr( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + // Strip an extra layer of arraying if present. + if (_.GetIdOpcode(underlying_type) == SpvOpTypeArray) { + uint32_t subtype = _.FindDef(underlying_type)->word(2u); + if (_.GetIdOpcode(subtype) == SpvOpTypeArray) { + underlying_type = subtype; + } + } + + return ValidateF32ArrHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32ArrHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type) { const Instruction* const type_inst = _.FindDef(underlying_type); if (type_inst->opcode() != SpvOpTypeArray) { return diag(GetDefinitionDesc(decoration, inst) + " is not an array."); @@ -699,7 +774,7 @@ spv_result_t BuiltInsValidator::ValidateNotCalledWithExecutionModel( SPV_OPERAND_TYPE_EXECUTION_MODEL, execution_model); const char* built_in_str = _.grammar().lookupOperandName( SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]); - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << comment << " " << GetIdDesc(referenced_inst) << " depends on " << GetIdDesc(built_in_inst) << " which is decorated with BuiltIn " << built_in_str << "." @@ -720,21 +795,6 @@ spv_result_t BuiltInsValidator::ValidateNotCalledWithExecutionModel( spv_result_t BuiltInsValidator::ValidateClipOrCullDistanceAtDefinition( const Decoration& decoration, const Instruction& inst) { - if (spvIsVulkanEnv(_.context()->target_env)) { - if (spv_result_t error = ValidateF32Arr( - decoration, inst, /* Any number of components */ 0, - [this, &decoration](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) - << "According to the Vulkan spec BuiltIn " - << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, - decoration.params()[0]) - << " variable needs to be a 32-bit float array. " - << message; - })) { - return error; - } - } - // Seed at reference checks with this built-in. return ValidateClipOrCullDistanceAtReference(decoration, inst, inst, inst); } @@ -748,7 +808,7 @@ spv_result_t BuiltInsValidator::ValidateClipOrCullDistanceAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -784,16 +844,65 @@ spv_result_t BuiltInsValidator::ValidateClipOrCullDistanceAtReference( for (const SpvExecutionModel execution_model : execution_models_) { switch (execution_model) { case SpvExecutionModelFragment: - case SpvExecutionModelVertex: + case SpvExecutionModelVertex: { + if (spv_result_t error = ValidateF32Arr( + decoration, built_in_inst, /* Any number of components */ 0, + [this, &decoration, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 32-bit float array. " + << message; + })) { + return error; + } + break; + } case SpvExecutionModelTessellationControl: case SpvExecutionModelTessellationEvaluation: case SpvExecutionModelGeometry: { - // Ok. + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + // The outer level of array is applied on the variable. + if (spv_result_t error = ValidateF32Arr( + decoration, built_in_inst, /* Any number of components */ 0, + [this, &decoration, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 32-bit float array. " + << message; + })) { + return error; + } + } else { + if (spv_result_t error = ValidateOptionalArrayedF32Arr( + decoration, built_in_inst, /* Any number of components */ 0, + [this, &decoration, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 32-bit float array. " + << message; + })) { + return error; + } + } break; } default: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -823,8 +932,8 @@ spv_result_t BuiltInsValidator::ValidateFragCoordAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32Vec( decoration, inst, 4, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn FragCoord " "variable needs to be a 4-component 32-bit float " "vector. " @@ -846,7 +955,7 @@ spv_result_t BuiltInsValidator::ValidateFragCoordAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn FragCoord to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -856,7 +965,7 @@ spv_result_t BuiltInsValidator::ValidateFragCoordAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn FragCoord to be used only with " "Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -880,8 +989,8 @@ spv_result_t BuiltInsValidator::ValidateFragDepthAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn FragDepth " "variable needs to be a 32-bit float scalar. " << message; @@ -902,7 +1011,7 @@ spv_result_t BuiltInsValidator::ValidateFragDepthAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn FragDepth to be only used for " "variables with Output storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -912,7 +1021,7 @@ spv_result_t BuiltInsValidator::ValidateFragDepthAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn FragDepth to be used only with " "Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -925,7 +1034,7 @@ spv_result_t BuiltInsValidator::ValidateFragDepthAtReference( // Execution Mode DepthReplacing. const auto* modes = _.GetExecutionModes(entry_point); if (!modes || !modes->count(SpvExecutionModeDepthReplacing)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec requires DepthReplacing execution mode to be " "declared when using BuiltIn FragDepth. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -949,8 +1058,8 @@ spv_result_t BuiltInsValidator::ValidateFrontFacingAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateBool( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn FrontFacing " "variable needs to be a bool scalar. " << message; @@ -971,7 +1080,7 @@ spv_result_t BuiltInsValidator::ValidateFrontFacingAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn FrontFacing to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -981,7 +1090,7 @@ spv_result_t BuiltInsValidator::ValidateFrontFacingAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn FrontFacing to be used only with " "Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1005,8 +1114,8 @@ spv_result_t BuiltInsValidator::ValidateHelperInvocationAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateBool( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn HelperInvocation " "variable needs to be a bool scalar. " << message; @@ -1027,7 +1136,7 @@ spv_result_t BuiltInsValidator::ValidateHelperInvocationAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn HelperInvocation to be only used " "for variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1037,7 +1146,7 @@ spv_result_t BuiltInsValidator::ValidateHelperInvocationAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn HelperInvocation to be used only " "with Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1062,8 +1171,8 @@ spv_result_t BuiltInsValidator::ValidateInvocationIdAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn InvocationId " "variable needs to be a 32-bit int scalar. " << message; @@ -1084,7 +1193,7 @@ spv_result_t BuiltInsValidator::ValidateInvocationIdAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn InvocationId to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1095,7 +1204,7 @@ spv_result_t BuiltInsValidator::ValidateInvocationIdAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelTessellationControl && execution_model != SpvExecutionModelGeometry) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn InvocationId to be used only " "with TessellationControl or Geometry execution models. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1119,8 +1228,8 @@ spv_result_t BuiltInsValidator::ValidateInstanceIndexAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn InstanceIndex " "variable needs to be a 32-bit int scalar. " << message; @@ -1141,7 +1250,7 @@ spv_result_t BuiltInsValidator::ValidateInstanceIndexAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn InstanceIndex to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1151,7 +1260,7 @@ spv_result_t BuiltInsValidator::ValidateInstanceIndexAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelVertex) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn InstanceIndex to be used only " "with Vertex execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1175,8 +1284,8 @@ spv_result_t BuiltInsValidator::ValidatePatchVerticesAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn PatchVertices " "variable needs to be a 32-bit int scalar. " << message; @@ -1197,7 +1306,7 @@ spv_result_t BuiltInsValidator::ValidatePatchVerticesAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PatchVertices to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1208,7 +1317,7 @@ spv_result_t BuiltInsValidator::ValidatePatchVerticesAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelTessellationControl && execution_model != SpvExecutionModelTessellationEvaluation) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PatchVertices to be used only " "with TessellationControl or TessellationEvaluation " "execution models. " @@ -1233,8 +1342,8 @@ spv_result_t BuiltInsValidator::ValidatePointCoordAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32Vec( decoration, inst, 2, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn PointCoord " "variable needs to be a 2-component 32-bit float " "vector. " @@ -1256,7 +1365,7 @@ spv_result_t BuiltInsValidator::ValidatePointCoordAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PointCoord to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1266,7 +1375,7 @@ spv_result_t BuiltInsValidator::ValidatePointCoordAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PointCoord to be used only with " "Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1287,19 +1396,6 @@ spv_result_t BuiltInsValidator::ValidatePointCoordAtReference( spv_result_t BuiltInsValidator::ValidatePointSizeAtDefinition( const Decoration& decoration, const Instruction& inst) { - if (spvIsVulkanEnv(_.context()->target_env)) { - if (spv_result_t error = ValidateF32( - decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) - << "According to the Vulkan spec BuiltIn PointSize " - "variable needs to be a 32-bit float scalar. " - << message; - })) { - return error; - } - } - // Seed at reference checks with this built-in. return ValidatePointSizeAtReference(decoration, inst, inst, inst); } @@ -1313,7 +1409,7 @@ spv_result_t BuiltInsValidator::ValidatePointSizeAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PointSize to be only used for " "variables with Input or Output storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1333,16 +1429,61 @@ spv_result_t BuiltInsValidator::ValidatePointSizeAtReference( for (const SpvExecutionModel execution_model : execution_models_) { switch (execution_model) { - case SpvExecutionModelVertex: + case SpvExecutionModelVertex: { + if (spv_result_t error = ValidateF32( + decoration, built_in_inst, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "According to the Vulkan spec BuiltIn PointSize " + "variable needs to be a 32-bit float scalar. " + << message; + })) { + return error; + } + break; + } case SpvExecutionModelTessellationControl: case SpvExecutionModelTessellationEvaluation: case SpvExecutionModelGeometry: { - // Ok. + // PointSize can be a per-vertex variable for tessellation control, + // tessellation evaluation and geometry shader stages. In such cases + // variables will have an array of 32-bit floats. + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + // The array is on the variable, so this must be a 32-bit float. + if (spv_result_t error = ValidateF32( + decoration, built_in_inst, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + "PointSize variable needs to be a 32-bit " + "float scalar. " + << message; + })) { + return error; + } + } else { + if (spv_result_t error = ValidateOptionalArrayedF32( + decoration, built_in_inst, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + "PointSize variable needs to be a 32-bit " + "float scalar. " + << message; + })) { + return error; + } + } break; } default: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PointSize to be used only with " "Vertex, TessellationControl, TessellationEvaluation or " "Geometry execution models. " @@ -1365,20 +1506,6 @@ spv_result_t BuiltInsValidator::ValidatePointSizeAtReference( spv_result_t BuiltInsValidator::ValidatePositionAtDefinition( const Decoration& decoration, const Instruction& inst) { - if (spvIsVulkanEnv(_.context()->target_env)) { - if (spv_result_t error = ValidateF32Vec( - decoration, inst, 4, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) - << "According to the Vulkan spec BuiltIn Position " - "variable needs to be a 4-component 32-bit float " - "vector. " - << message; - })) { - return error; - } - } - // Seed at reference checks with this built-in. return ValidatePositionAtReference(decoration, inst, inst, inst); } @@ -1392,7 +1519,7 @@ spv_result_t BuiltInsValidator::ValidatePositionAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn Position to be only used for " "variables with Input or Output storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1412,19 +1539,66 @@ spv_result_t BuiltInsValidator::ValidatePositionAtReference( for (const SpvExecutionModel execution_model : execution_models_) { switch (execution_model) { - case SpvExecutionModelVertex: + case SpvExecutionModelVertex: { + if (spv_result_t error = ValidateF32Vec( + decoration, built_in_inst, 4, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "According to the Vulkan spec BuiltIn Position " + "variable needs to be a 4-component 32-bit float " + "vector. " + << message; + })) { + return error; + } + break; + } + case SpvExecutionModelGeometry: case SpvExecutionModelTessellationControl: - case SpvExecutionModelTessellationEvaluation: - case SpvExecutionModelGeometry: { - // Ok. + case SpvExecutionModelTessellationEvaluation: { + // Position can be a per-vertex variable for tessellation control, + // tessellation evaluation and geometry shader stages. In such cases + // variables will have an array of 4-component 32-bit float vectors. + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + // The array is on the variable, so this must be a 4-component + // 32-bit float vector. + if (spv_result_t error = ValidateF32Vec( + decoration, built_in_inst, 4, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn Position " + "variable needs to be a 4-component 32-bit " + "float vector. " + << message; + })) { + return error; + } + } else { + if (spv_result_t error = ValidateOptionalArrayedF32Vec( + decoration, built_in_inst, 4, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn Position " + "variable needs to be a 4-component 32-bit " + "float vector. " + << message; + })) { + return error; + } + } break; } default: { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Vulkan spec allows BuiltIn Position to be used only with " - "Vertex, TessellationControl, TessellationEvaluation or " - "Geometry execution models. " + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn Position to be used only " + "with Vertex, TessellationControl, TessellationEvaluation" + " or Geometry execution models. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, referenced_from_inst, execution_model); } @@ -1447,8 +1621,8 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn PrimitiveId " "variable needs to be a 32-bit int scalar. " << message; @@ -1470,7 +1644,7 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PrimitiveId to be only used for " "variables with Input or Output storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1497,7 +1671,8 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtReference( id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, "Vulkan spec doesn't allow BuiltIn PrimitiveId to be used for " - "variables with Output storage class if execution model is Fragment.", + "variables with Output storage class if execution model is " + "Fragment.", SpvExecutionModelFragment, decoration, built_in_inst, referenced_from_inst, std::placeholders::_1)); } @@ -1513,7 +1688,7 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtReference( } default: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn PrimitiveId to be used only " "with Fragment, TessellationControl, " "TessellationEvaluation or Geometry execution models. " @@ -1539,8 +1714,8 @@ spv_result_t BuiltInsValidator::ValidateSampleIdAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn SampleId " "variable needs to be a 32-bit int scalar. " << message; @@ -1561,7 +1736,7 @@ spv_result_t BuiltInsValidator::ValidateSampleIdAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn SampleId to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1571,7 +1746,7 @@ spv_result_t BuiltInsValidator::ValidateSampleIdAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn SampleId to be used only with " "Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1595,8 +1770,8 @@ spv_result_t BuiltInsValidator::ValidateSampleMaskAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32Arr( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn SampleMask " "variable needs to be a 32-bit int array. " << message; @@ -1618,7 +1793,7 @@ spv_result_t BuiltInsValidator::ValidateSampleMaskAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn SampleMask to be only used for " "variables with Input or Output storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1628,8 +1803,9 @@ spv_result_t BuiltInsValidator::ValidateSampleMaskAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Vulkan spec allows BuiltIn SampleMask to be used only with " + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SampleMask to be used only " + "with " "Fragment execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, referenced_from_inst, execution_model); @@ -1652,8 +1828,8 @@ spv_result_t BuiltInsValidator::ValidateSamplePositionAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32Vec( decoration, inst, 2, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn SamplePosition " "variable needs to be a 2-component 32-bit float " "vector. " @@ -1675,8 +1851,9 @@ spv_result_t BuiltInsValidator::ValidateSamplePositionAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Vulkan spec allows BuiltIn SamplePosition to be only used for " + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SamplePosition to be only used " + "for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, referenced_from_inst) @@ -1685,7 +1862,7 @@ spv_result_t BuiltInsValidator::ValidateSamplePositionAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelFragment) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn SamplePosition to be used only " "with " "Fragment execution model. " @@ -1710,8 +1887,8 @@ spv_result_t BuiltInsValidator::ValidateTessCoordAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32Vec( decoration, inst, 3, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn TessCoord " "variable needs to be a 3-component 32-bit float " "vector. " @@ -1733,7 +1910,7 @@ spv_result_t BuiltInsValidator::ValidateTessCoordAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn TessCoord to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1743,7 +1920,7 @@ spv_result_t BuiltInsValidator::ValidateTessCoordAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelTessellationEvaluation) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn TessCoord to be used only with " "TessellationEvaluation execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1767,8 +1944,8 @@ spv_result_t BuiltInsValidator::ValidateTessLevelOuterAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32Arr( decoration, inst, 4, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn TessLevelOuter " "variable needs to be a 4-component 32-bit float " "array. " @@ -1787,8 +1964,8 @@ spv_result_t BuiltInsValidator::ValidateTessLevelInnerAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateF32Arr( decoration, inst, 2, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn TessLevelOuter " "variable needs to be a 2-component 32-bit float " "array. " @@ -1811,7 +1988,7 @@ spv_result_t BuiltInsValidator::ValidateTessLevelAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -1826,7 +2003,8 @@ spv_result_t BuiltInsValidator::ValidateTessLevelAtReference( assert(function_id_ == 0); id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, - "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be used " + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used " "for variables with Input storage class if execution model is " "TessellationControl.", SpvExecutionModelTessellationControl, decoration, built_in_inst, @@ -1837,7 +2015,8 @@ spv_result_t BuiltInsValidator::ValidateTessLevelAtReference( assert(function_id_ == 0); id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, - "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be used " + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used " "for variables with Output storage class if execution model is " "TessellationEvaluation.", SpvExecutionModelTessellationEvaluation, decoration, built_in_inst, @@ -1853,7 +2032,7 @@ spv_result_t BuiltInsValidator::ValidateTessLevelAtReference( } default: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -1881,8 +2060,8 @@ spv_result_t BuiltInsValidator::ValidateVertexIndexAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn VertexIndex " "variable needs to be a 32-bit int scalar. " << message; @@ -1903,7 +2082,7 @@ spv_result_t BuiltInsValidator::ValidateVertexIndexAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn VertexIndex to be only used for " "variables with Input storage class. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, @@ -1913,8 +2092,9 @@ spv_result_t BuiltInsValidator::ValidateVertexIndexAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelVertex) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Vulkan spec allows BuiltIn VertexIndex to be used only with " + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn VertexIndex to be used only " + "with " "Vertex execution model. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, referenced_from_inst, execution_model); @@ -1937,8 +2117,9 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32( decoration, inst, - [this, &decoration](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &decoration, + &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -1961,7 +2142,7 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtReference( if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput && storage_class != SpvStorageClassOutput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -1974,21 +2155,30 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtReference( if (storage_class == SpvStorageClassInput) { assert(function_id_ == 0); - id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( - &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, - "Vulkan spec doesn't allow BuiltIn Layer and ViewportIndex to be " - "used for variables with Input storage class if execution model is " - "Geometry.", - SpvExecutionModelGeometry, decoration, built_in_inst, - referenced_from_inst, std::placeholders::_1)); + for (const auto em : + {SpvExecutionModelVertex, SpvExecutionModelTessellationEvaluation, + SpvExecutionModelGeometry}) { + id_to_at_reference_checks_[referenced_from_inst.id()].push_back( + std::bind(&BuiltInsValidator::ValidateNotCalledWithExecutionModel, + this, + "Vulkan spec doesn't allow BuiltIn Layer and " + "ViewportIndex to be " + "used for variables with Input storage class if " + "execution model is Vertex, TessellationEvaluation, or " + "Geometry.", + em, decoration, built_in_inst, referenced_from_inst, + std::placeholders::_1)); + } } if (storage_class == SpvStorageClassOutput) { assert(function_id_ == 0); id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, - "Vulkan spec doesn't allow BuiltIn Layer and ViewportIndex to be " - "used for variables with Output storage class if execution model is " + "Vulkan spec doesn't allow BuiltIn Layer and " + "ViewportIndex to be " + "used for variables with Output storage class if " + "execution model is " "Fragment.", SpvExecutionModelFragment, decoration, built_in_inst, referenced_from_inst, std::placeholders::_1)); @@ -2000,15 +2190,25 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtReference( case SpvExecutionModelFragment: { // Ok. break; + case SpvExecutionModelVertex: + case SpvExecutionModelTessellationEvaluation: + if (!_.HasCapability(SpvCapabilityShaderViewportIndexLayerEXT)) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Using BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " in Vertex or Tessellation execution model requires " + "the ShaderViewportIndexLayerEXT capability."; + } + break; } - default: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) - << " to be used only with Fragment or Geometry execution " - "models. " + << " to be used only with Vertex, TessellationEvaluation, " + "Geometry, or Fragment execution models. " << GetReferenceDesc(decoration, built_in_inst, referenced_inst, referenced_from_inst, execution_model); } @@ -2032,8 +2232,9 @@ spv_result_t BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtDefinition( if (spvIsVulkanEnv(_.context()->target_env)) { if (spv_result_t error = ValidateI32Vec( decoration, inst, 3, - [this, &decoration](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &decoration, + &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -2058,7 +2259,7 @@ spv_result_t BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtReference( const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); if (storage_class != SpvStorageClassMax && storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -2070,7 +2271,7 @@ spv_result_t BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtReference( for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelGLCompute) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -2096,15 +2297,16 @@ spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition( const Decoration& decoration, const Instruction& inst) { if (spvIsVulkanEnv(_.context()->target_env)) { if (!spvOpcodeIsConstant(inst.opcode())) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Vulkan spec requires BuiltIn WorkgroupSize to be a constant. " + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "Vulkan spec requires BuiltIn WorkgroupSize to be a " + "constant. " << GetIdDesc(inst) << " is not a constant."; } if (spv_result_t error = ValidateI32Vec( decoration, inst, 3, - [this](const std::string& message) -> spv_result_t { - return _.diag(SPV_ERROR_INVALID_DATA) + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) << "According to the Vulkan spec BuiltIn WorkgroupSize " "variable " "needs to be a 3-component 32-bit int vector. " @@ -2123,22 +2325,9 @@ spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtReference( const Instruction& referenced_inst, const Instruction& referenced_from_inst) { if (spvIsVulkanEnv(_.context()->target_env)) { - const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); - if (storage_class != SpvStorageClassMax && - storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Vulkan spec allows BuiltIn " - << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, - decoration.params()[0]) - << " to be only used for variables with Input storage class. " - << GetReferenceDesc(decoration, built_in_inst, referenced_inst, - referenced_from_inst) - << " " << GetStorageClassDesc(referenced_from_inst); - } - for (const SpvExecutionModel execution_model : execution_models_) { if (execution_model != SpvExecutionModelGLCompute) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) << "Vulkan spec allows BuiltIn " << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]) @@ -2321,8 +2510,6 @@ spv_result_t BuiltInsValidator::Run() { return SPV_SUCCESS; } - ComputeFunctionToEntryPointMapping(); - // Second pass: validate every id reference in the module using // rules in id_to_at_reference_checks_. for (const Instruction& inst : _.ordered_instructions()) { @@ -2347,8 +2534,8 @@ spv_result_t BuiltInsValidator::Run() { continue; } - // Instruction references the id. Run all checks associated with the id on - // the instruction. id_to_at_reference_checks_ can be modified in the + // Instruction references the id. Run all checks associated with the id + // on the instruction. id_to_at_reference_checks_ can be modified in the // process, iterators are safe because it's a tree-based map. const auto it = id_to_at_reference_checks_.find(id); if (it != id_to_at_reference_checks_.end()) { @@ -2364,7 +2551,7 @@ spv_result_t BuiltInsValidator::Run() { return SPV_SUCCESS; } -} // anonymous namespace +} // namespace // Validates correctness of built-in variables. spv_result_t ValidateBuiltIns(const ValidationState_t& _) { @@ -2383,4 +2570,5 @@ spv_result_t ValidateBuiltIns(const ValidationState_t& _) { return validator.Run(); } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_capability.cpp b/3rdparty/spirv-tools/source/val/validate_capability.cpp similarity index 76% rename from 3rdparty/spirv-tools/source/validate_capability.cpp rename to 3rdparty/spirv-tools/source/val/validate_capability.cpp index 1a47ebc5e..4724b9f79 100644 --- a/3rdparty/spirv-tools/source/validate_capability.cpp +++ b/3rdparty/spirv-tools/source/val/validate_capability.cpp @@ -14,18 +14,19 @@ // Validates OpCapability instruction. -#include "validate.h" +#include "source/val/validate.h" #include +#include #include -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { bool IsSupportGuaranteedVulkan_1_0(uint32_t capability) { @@ -44,6 +45,16 @@ bool IsSupportGuaranteedVulkan_1_0(uint32_t capability) { return false; } +bool IsSupportGuaranteedVulkan_1_1(uint32_t capability) { + if (IsSupportGuaranteedVulkan_1_0(capability)) return true; + switch (capability) { + case SpvCapabilityDeviceGroup: + case SpvCapabilityMultiView: + return true; + } + return false; +} + bool IsSupportOptionalVulkan_1_0(uint32_t capability) { switch (capability) { case SpvCapabilityGeometry: @@ -77,6 +88,34 @@ bool IsSupportOptionalVulkan_1_0(uint32_t capability) { return false; } +bool IsSupportOptionalVulkan_1_1(uint32_t capability) { + if (IsSupportOptionalVulkan_1_0(capability)) return true; + + switch (capability) { + case SpvCapabilityGroupNonUniform: + case SpvCapabilityGroupNonUniformVote: + case SpvCapabilityGroupNonUniformArithmetic: + case SpvCapabilityGroupNonUniformBallot: + case SpvCapabilityGroupNonUniformShuffle: + case SpvCapabilityGroupNonUniformShuffleRelative: + case SpvCapabilityGroupNonUniformClustered: + case SpvCapabilityGroupNonUniformQuad: + case SpvCapabilityDrawParameters: + // Alias SpvCapabilityStorageBuffer16BitAccess. + case SpvCapabilityStorageUniformBufferBlock16: + // Alias SpvCapabilityUniformAndStorageBuffer16BitAccess. + case SpvCapabilityStorageUniform16: + case SpvCapabilityStoragePushConstant16: + case SpvCapabilityStorageInputOutput16: + case SpvCapabilityDeviceGroup: + case SpvCapabilityMultiView: + case SpvCapabilityVariablePointersStorageBuffer: + case SpvCapabilityVariablePointers: + return true; + } + return false; +} + bool IsSupportGuaranteedOpenCL_1_2(uint32_t capability, bool embedded_profile) { switch (capability) { case SpvCapabilityAddresses: @@ -182,19 +221,17 @@ bool IsEnabledByCapabilityOpenCL_2_0(ValidationState_t& _, // Validates that capability declarations use operands allowed in the current // context. -spv_result_t CapabilityPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - if (opcode != SpvOpCapability) return SPV_SUCCESS; +spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst) { + if (inst->opcode() != SpvOpCapability) return SPV_SUCCESS; - assert(inst->num_operands == 1); + assert(inst->operands().size() == 1); - const spv_parsed_operand_t& operand = inst->operands[0]; + const spv_parsed_operand_t& operand = inst->operand(0); assert(operand.num_words == 1); - assert(operand.offset < inst->num_words); + assert(operand.offset < inst->words().size()); - const uint32_t capability = inst->words[operand.offset]; + const uint32_t capability = inst->word(operand.offset); const auto capability_str = [&_, capability]() { spv_operand_desc desc = nullptr; if (_.grammar().lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, capability, @@ -215,17 +252,26 @@ spv_result_t CapabilityPass(ValidationState_t& _, if (!IsSupportGuaranteedVulkan_1_0(capability) && !IsSupportOptionalVulkan_1_0(capability) && !IsEnabledByExtension(_, capability)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) << "Capability " << capability_str() << " is not allowed by Vulkan 1.0 specification" << " (or requires extension)"; } + } else if (env == SPV_ENV_VULKAN_1_1) { + if (!IsSupportGuaranteedVulkan_1_1(capability) && + !IsSupportOptionalVulkan_1_1(capability) && + !IsEnabledByExtension(_, capability)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by Vulkan 1.1 specification" + << " (or requires extension)"; + } } else if (env == SPV_ENV_OPENCL_1_2 || env == SPV_ENV_OPENCL_EMBEDDED_1_2) { if (!IsSupportGuaranteedOpenCL_1_2(capability, opencl_embedded) && !IsSupportOptionalOpenCL_1_2(capability) && !IsEnabledByExtension(_, capability) && !IsEnabledByCapabilityOpenCL_1_2(_, capability)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) << "Capability " << capability_str() << " is not allowed by OpenCL 1.2 " << opencl_profile << " Profile specification" @@ -237,7 +283,7 @@ spv_result_t CapabilityPass(ValidationState_t& _, !IsSupportOptionalOpenCL_1_2(capability) && !IsEnabledByExtension(_, capability) && !IsEnabledByCapabilityOpenCL_2_0(_, capability)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) << "Capability " << capability_str() << " is not allowed by OpenCL 2.0/2.1 " << opencl_profile << " Profile specification" @@ -248,7 +294,7 @@ spv_result_t CapabilityPass(ValidationState_t& _, !IsSupportOptionalOpenCL_1_2(capability) && !IsEnabledByExtension(_, capability) && !IsEnabledByCapabilityOpenCL_2_0(_, capability)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) << "Capability " << capability_str() << " is not allowed by OpenCL 2.2 " << opencl_profile << " Profile specification" @@ -259,4 +305,5 @@ spv_result_t CapabilityPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_cfg.cpp b/3rdparty/spirv-tools/source/val/validate_cfg.cpp new file mode 100644 index 000000000..744641e97 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_cfg.cpp @@ -0,0 +1,765 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/cfa.h" +#include "source/opcode.h" +#include "source/spirv_validator_options.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) { + auto block = inst->block(); + size_t num_in_ops = inst->words().size() - 3; + if (num_in_ops % 2 != 0) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi does not have an equal number of incoming values and " + "basic blocks."; + } + + // Create a uniqued vector of predecessor ids for comparison against + // incoming values. OpBranchConditional %cond %label %label produces two + // predecessors in the CFG. + std::vector pred_ids; + std::transform(block->predecessors()->begin(), block->predecessors()->end(), + std::back_inserter(pred_ids), + [](const BasicBlock* b) { return b->id(); }); + std::sort(pred_ids.begin(), pred_ids.end()); + pred_ids.erase(std::unique(pred_ids.begin(), pred_ids.end()), pred_ids.end()); + + size_t num_edges = num_in_ops / 2; + if (num_edges != pred_ids.size()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's number of incoming blocks (" << num_edges + << ") does not match block's predecessor count (" + << block->predecessors()->size() << ")."; + } + + for (size_t i = 3; i < inst->words().size(); ++i) { + auto inc_id = inst->word(i); + if (i % 2 == 1) { + // Incoming value type must match the phi result type. + auto inc_type_id = _.GetTypeId(inc_id); + if (inst->type_id() != inc_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's result type " << _.getIdName(inst->type_id()) + << " does not match incoming value " << _.getIdName(inc_id) + << " type " << _.getIdName(inc_type_id) << "."; + } + } else { + if (_.GetIdOpcode(inc_id) != SpvOpLabel) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's incoming basic block " << _.getIdName(inc_id) + << " is not an OpLabel."; + } + + // Incoming basic block must be an immediate predecessor of the phi's + // block. + if (!std::binary_search(pred_ids.begin(), pred_ids.end(), inc_id)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's incoming basic block " << _.getIdName(inc_id) + << " is not a predecessor of " << _.getIdName(block->id()) + << "."; + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateBranchConditional(ValidationState_t& _, + const Instruction* inst) { + // num_operands is either 3 or 5 --- if 5, the last two need to be literal + // integers + const auto num_operands = inst->operands().size(); + if (num_operands != 3 && num_operands != 5) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpBranchConditional requires either 3 or 5 parameters"; + } + + // grab the condition operand and check that it is a bool + const auto cond_id = inst->GetOperandAs(0); + const auto cond_op = _.FindDef(cond_id); + if (!cond_op || !_.IsBoolScalarType(cond_op->type_id())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) << "Condition operand for " + "OpBranchConditional must be " + "of boolean type"; + } + + // target operands must be OpLabel + // note that we don't need to check that the target labels are in the same + // function, + // PerformCfgChecks already checks for that + const auto true_id = inst->GetOperandAs(1); + const auto true_target = _.FindDef(true_id); + if (!true_target || SpvOpLabel != true_target->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The 'True Label' operand for OpBranchConditional must be the " + "ID of an OpLabel instruction"; + } + + const auto false_id = inst->GetOperandAs(2); + const auto false_target = _.FindDef(false_id); + if (!false_target || SpvOpLabel != false_target->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The 'False Label' operand for OpBranchConditional must be the " + "ID of an OpLabel instruction"; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateReturnValue(ValidationState_t& _, + const Instruction* inst) { + const auto value_id = inst->GetOperandAs(0); + const auto value = _.FindDef(value_id); + if (!value || !value->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue Value '" << _.getIdName(value_id) + << "' does not represent a value."; + } + auto value_type = _.FindDef(value->type_id()); + if (!value_type || SpvOpTypeVoid == value_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue value's type '" + << _.getIdName(value->type_id()) << "' is missing or void."; + } + + const bool uses_variable_pointer = + _.features().variable_pointers || + _.features().variable_pointers_storage_buffer; + + if (_.addressing_model() == SpvAddressingModelLogical && + SpvOpTypePointer == value_type->opcode() && !uses_variable_pointer && + !_.options()->relax_logical_pointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue value's type '" + << _.getIdName(value->type_id()) + << "' is a pointer, which is invalid in the Logical addressing " + "model."; + } + + const auto function = inst->function(); + const auto return_type = _.FindDef(function->GetResultTypeId()); + if (!return_type || return_type->id() != value_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue Value '" << _.getIdName(value_id) + << "'s type does not match OpFunction's return type."; + } + + return SPV_SUCCESS; +} + +} // namespace + +void printDominatorList(const BasicBlock& b) { + std::cout << b.id() << " is dominated by: "; + const BasicBlock* bb = &b; + while (bb->immediate_dominator() != bb) { + bb = bb->immediate_dominator(); + std::cout << bb->id() << " "; + } +} + +#define CFG_ASSERT(ASSERT_FUNC, TARGET) \ + if (spv_result_t rcode = ASSERT_FUNC(_, TARGET)) return rcode + +spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) { + if (_.current_function().IsFirstBlock(target)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(_.current_function().id())) + << "First block " << _.getIdName(target) << " of function " + << _.getIdName(_.current_function().id()) << " is targeted by block " + << _.getIdName(_.current_function().current_block()->id()); + } + return SPV_SUCCESS; +} + +spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) { + if (_.current_function().IsBlockType(merge_block, kBlockTypeMerge)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(_.current_function().id())) + << "Block " << _.getIdName(merge_block) + << " is already a merge block for another header"; + } + return SPV_SUCCESS; +} + +/// Update the continue construct's exit blocks once the backedge blocks are +/// identified in the CFG. +void UpdateContinueConstructExitBlocks( + Function& function, + const std::vector>& back_edges) { + auto& constructs = function.constructs(); + // TODO(umar): Think of a faster way to do this + for (auto& edge : back_edges) { + uint32_t back_edge_block_id; + uint32_t loop_header_block_id; + std::tie(back_edge_block_id, loop_header_block_id) = edge; + auto is_this_header = [=](Construct& c) { + return c.type() == ConstructType::kLoop && + c.entry_block()->id() == loop_header_block_id; + }; + + for (auto construct : constructs) { + if (is_this_header(construct)) { + Construct* continue_construct = + construct.corresponding_constructs().back(); + assert(continue_construct->type() == ConstructType::kContinue); + + BasicBlock* back_edge_block; + std::tie(back_edge_block, std::ignore) = + function.GetBlock(back_edge_block_id); + continue_construct->set_exit(back_edge_block); + } + } + } +} + +std::tuple ConstructNames( + ConstructType type) { + std::string construct_name, header_name, exit_name; + + switch (type) { + case ConstructType::kSelection: + construct_name = "selection"; + header_name = "selection header"; + exit_name = "merge block"; + break; + case ConstructType::kLoop: + construct_name = "loop"; + header_name = "loop header"; + exit_name = "merge block"; + break; + case ConstructType::kContinue: + construct_name = "continue"; + header_name = "continue target"; + exit_name = "back-edge block"; + break; + case ConstructType::kCase: + construct_name = "case"; + header_name = "case entry block"; + exit_name = "case exit block"; + break; + default: + assert(1 == 0 && "Not defined type"); + } + + return std::make_tuple(construct_name, header_name, exit_name); +} + +/// Constructs an error message for construct validation errors +std::string ConstructErrorString(const Construct& construct, + const std::string& header_string, + const std::string& exit_string, + const std::string& dominate_text) { + std::string construct_name, header_name, exit_name; + std::tie(construct_name, header_name, exit_name) = + ConstructNames(construct.type()); + + // TODO(umar): Add header block for continue constructs to error message + return "The " + construct_name + " construct with the " + header_name + " " + + header_string + " " + dominate_text + " the " + exit_name + " " + + exit_string; +} + +// Finds the fall through case construct of |target_block| and records it in +// |case_fall_through|. Returns SPV_ERROR_INVALID_CFG if the case construct +// headed by |target_block| branches to multiple case constructs. +spv_result_t FindCaseFallThrough( + const ValidationState_t& _, BasicBlock* target_block, + uint32_t* case_fall_through, const BasicBlock* merge, + const std::unordered_set& case_targets, Function* function) { + std::vector stack; + stack.push_back(target_block); + std::unordered_set visited; + bool target_reachable = target_block->reachable(); + int target_depth = function->GetBlockDepth(target_block); + while (!stack.empty()) { + auto block = stack.back(); + stack.pop_back(); + + if (block == merge) continue; + + if (!visited.insert(block).second) continue; + + if (target_reachable && block->reachable() && + target_block->dominates(*block)) { + // Still in the case construct. + for (auto successor : *block->successors()) { + stack.push_back(successor); + } + } else { + // Exiting the case construct to non-merge block. + if (!case_targets.count(block->id())) { + int depth = function->GetBlockDepth(block); + if ((depth < target_depth) || + (depth == target_depth && block->is_type(kBlockTypeContinue))) { + continue; + } + + return _.diag(SPV_ERROR_INVALID_CFG, target_block->label()) + << "Case construct that targets " + << _.getIdName(target_block->id()) + << " has invalid branch to block " << _.getIdName(block->id()) + << " (not another case construct, corresponding merge, outer " + "loop merge or outer loop continue)"; + } + + if (*case_fall_through == 0u) { + *case_fall_through = block->id(); + } else if (*case_fall_through != block->id()) { + // Case construct has at most one branch to another case construct. + return _.diag(SPV_ERROR_INVALID_CFG, target_block->label()) + << "Case construct that targets " + << _.getIdName(target_block->id()) + << " has branches to multiple other case construct targets " + << _.getIdName(*case_fall_through) << " and " + << _.getIdName(block->id()); + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t StructuredSwitchChecks(const ValidationState_t& _, + Function* function, + const Instruction* switch_inst, + const BasicBlock* header, + const BasicBlock* merge) { + std::unordered_set case_targets; + for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) { + uint32_t target = switch_inst->GetOperandAs(i); + if (target != merge->id()) case_targets.insert(target); + } + // Tracks how many times each case construct is targeted by another case + // construct. + std::map num_fall_through_targeted; + uint32_t default_case_fall_through = 0u; + uint32_t default_target = switch_inst->GetOperandAs(1u); + std::unordered_set seen; + for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) { + uint32_t target = switch_inst->GetOperandAs(i); + if (target == merge->id()) continue; + + if (!seen.insert(target).second) continue; + + const auto target_block = function->GetBlock(target).first; + // OpSwitch must dominate all its case constructs. + if (header->reachable() && target_block->reachable() && + !header->dominates(*target_block)) { + return _.diag(SPV_ERROR_INVALID_CFG, header->label()) + << "Selection header " << _.getIdName(header->id()) + << " does not dominate its case construct " << _.getIdName(target); + } + + uint32_t case_fall_through = 0u; + if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through, + merge, case_targets, function)) { + return error; + } + + // Track how many time the fall through case has been targeted. + if (case_fall_through != 0u) { + auto where = num_fall_through_targeted.lower_bound(case_fall_through); + if (where == num_fall_through_targeted.end() || + where->first != case_fall_through) { + num_fall_through_targeted.insert(where, + std::make_pair(case_fall_through, 1)); + } else { + where->second++; + } + } + + if (case_fall_through == default_target) { + case_fall_through = default_case_fall_through; + } + if (case_fall_through != 0u) { + bool is_default = i == 1; + if (is_default) { + default_case_fall_through = case_fall_through; + } else { + // Allow code like: + // case x: + // case y: + // ... + // case z: + // + // Where x and y target the same block and fall through to z. + uint32_t j = i; + while ((j + 2 < switch_inst->operands().size()) && + target == switch_inst->GetOperandAs(j + 2)) { + j += 2; + } + // If Target T1 branches to Target T2, or if Target T1 branches to the + // Default target and the Default target branches to Target T2, then T1 + // must immediately precede T2 in the list of OpSwitch Target operands. + if ((switch_inst->operands().size() < j + 2) || + (case_fall_through != switch_inst->GetOperandAs(j + 2))) { + return _.diag(SPV_ERROR_INVALID_CFG, switch_inst) + << "Case construct that targets " << _.getIdName(target) + << " has branches to the case construct that targets " + << _.getIdName(case_fall_through) + << ", but does not immediately precede it in the " + "OpSwitch's target list"; + } + } + } + } + + // Each case construct must be branched to by at most one other case + // construct. + for (const auto& pair : num_fall_through_targeted) { + if (pair.second > 1) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pair.first)) + << "Multiple case constructs have branches to the case construct " + "that targets " + << _.getIdName(pair.first); + } + } + + return SPV_SUCCESS; +} + +spv_result_t StructuredControlFlowChecks( + const ValidationState_t& _, Function* function, + const std::vector>& back_edges) { + /// Check all backedges target only loop headers and have exactly one + /// back-edge branching to it + + // Map a loop header to blocks with back-edges to the loop header. + std::map> loop_latch_blocks; + for (auto back_edge : back_edges) { + uint32_t back_edge_block; + uint32_t header_block; + std::tie(back_edge_block, header_block) = back_edge; + if (!function->IsBlockType(header_block, kBlockTypeLoop)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(back_edge_block)) + << "Back-edges (" << _.getIdName(back_edge_block) << " -> " + << _.getIdName(header_block) + << ") can only be formed between a block and a loop header."; + } + loop_latch_blocks[header_block].insert(back_edge_block); + } + + // Check the loop headers have exactly one back-edge branching to it + for (BasicBlock* loop_header : function->ordered_blocks()) { + if (!loop_header->reachable()) continue; + if (!loop_header->is_type(kBlockTypeLoop)) continue; + auto loop_header_id = loop_header->id(); + auto num_latch_blocks = loop_latch_blocks[loop_header_id].size(); + if (num_latch_blocks != 1) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(loop_header_id)) + << "Loop header " << _.getIdName(loop_header_id) + << " is targeted by " << num_latch_blocks + << " back-edge blocks but the standard requires exactly one"; + } + } + + // Check construct rules + for (const Construct& construct : function->constructs()) { + auto header = construct.entry_block(); + auto merge = construct.exit_block(); + + if (header->reachable() && !merge) { + std::string construct_name, header_name, exit_name; + std::tie(construct_name, header_name, exit_name) = + ConstructNames(construct.type()); + return _.diag(SPV_ERROR_INTERNAL, _.FindDef(header->id())) + << "Construct " + construct_name + " with " + header_name + " " + + _.getIdName(header->id()) + " does not have a " + + exit_name + ". This may be a bug in the validator."; + } + + // If the exit block is reachable then it's dominated by the + // header. + if (merge && merge->reachable()) { + if (!header->dominates(*merge)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id())) + << ConstructErrorString(construct, _.getIdName(header->id()), + _.getIdName(merge->id()), + "does not dominate"); + } + // If it's really a merge block for a selection or loop, then it must be + // *strictly* dominated by the header. + if (construct.ExitBlockIsMergeBlock() && (header == merge)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id())) + << ConstructErrorString(construct, _.getIdName(header->id()), + _.getIdName(merge->id()), + "does not strictly dominate"); + } + } + // Check post-dominance for continue constructs. But dominance and + // post-dominance only make sense when the construct is reachable. + if (header->reachable() && construct.type() == ConstructType::kContinue) { + if (!merge->postdominates(*header)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id())) + << ConstructErrorString(construct, _.getIdName(header->id()), + _.getIdName(merge->id()), + "is not post dominated by"); + } + } + + // Check that for all non-header blocks, all predecessors are within this + // construct. + Construct::ConstructBlockSet construct_blocks = construct.blocks(function); + for (auto block : construct_blocks) { + if (block == header) continue; + for (auto pred : *block->predecessors()) { + if (pred->reachable() && !construct_blocks.count(pred)) { + std::string construct_name, header_name, exit_name; + std::tie(construct_name, header_name, exit_name) = + ConstructNames(construct.type()); + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pred->id())) + << "block " << pred->id() << " branches to the " + << construct_name << " construct, but not to the " + << header_name << " " << header->id(); + } + } + } + + // Checks rules for case constructs. + if (construct.type() == ConstructType::kSelection && + header->terminator()->opcode() == SpvOpSwitch) { + const auto terminator = header->terminator(); + if (auto error = + StructuredSwitchChecks(_, function, terminator, header, merge)) { + return error; + } + } + } + return SPV_SUCCESS; +} + +spv_result_t PerformCfgChecks(ValidationState_t& _) { + for (auto& function : _.functions()) { + // Check all referenced blocks are defined within a function + if (function.undefined_block_count() != 0) { + std::string undef_blocks("{"); + bool first = true; + for (auto undefined_block : function.undefined_blocks()) { + undef_blocks += _.getIdName(undefined_block); + if (!first) { + undef_blocks += " "; + } + first = false; + } + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(function.id())) + << "Block(s) " << undef_blocks << "}" + << " are referenced but not defined in function " + << _.getIdName(function.id()); + } + + // Set each block's immediate dominator and immediate postdominator, + // and find all back-edges. + // + // We want to analyze all the blocks in the function, even in degenerate + // control flow cases including unreachable blocks. So use the augmented + // CFG to ensure we cover all the blocks. + std::vector postorder; + std::vector postdom_postorder; + std::vector> back_edges; + auto ignore_block = [](const BasicBlock*) {}; + auto ignore_edge = [](const BasicBlock*, const BasicBlock*) {}; + if (!function.ordered_blocks().empty()) { + /// calculate dominators + CFA::DepthFirstTraversal( + function.first_block(), function.AugmentedCFGSuccessorsFunction(), + ignore_block, [&](const BasicBlock* b) { postorder.push_back(b); }, + ignore_edge); + auto edges = CFA::CalculateDominators( + postorder, function.AugmentedCFGPredecessorsFunction()); + for (auto edge : edges) { + edge.first->SetImmediateDominator(edge.second); + } + + /// calculate post dominators + CFA::DepthFirstTraversal( + function.pseudo_exit_block(), + function.AugmentedCFGPredecessorsFunction(), ignore_block, + [&](const BasicBlock* b) { postdom_postorder.push_back(b); }, + ignore_edge); + auto postdom_edges = CFA::CalculateDominators( + postdom_postorder, function.AugmentedCFGSuccessorsFunction()); + for (auto edge : postdom_edges) { + edge.first->SetImmediatePostDominator(edge.second); + } + /// calculate back edges. + CFA::DepthFirstTraversal( + function.pseudo_entry_block(), + function + .AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge(), + ignore_block, ignore_block, + [&](const BasicBlock* from, const BasicBlock* to) { + back_edges.emplace_back(from->id(), to->id()); + }); + } + UpdateContinueConstructExitBlocks(function, back_edges); + + auto& blocks = function.ordered_blocks(); + if (!blocks.empty()) { + // Check if the order of blocks in the binary appear before the blocks + // they dominate + for (auto block = begin(blocks) + 1; block != end(blocks); ++block) { + if (auto idom = (*block)->immediate_dominator()) { + if (idom != function.pseudo_entry_block() && + block == std::find(begin(blocks), block, idom)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(idom->id())) + << "Block " << _.getIdName((*block)->id()) + << " appears in the binary before its dominator " + << _.getIdName(idom->id()); + } + } + } + // If we have structed control flow, check that no block has a control + // flow nesting depth larger than the limit. + if (_.HasCapability(SpvCapabilityShader)) { + const int control_flow_nesting_depth_limit = + _.options()->universal_limits_.max_control_flow_nesting_depth; + for (auto block = begin(blocks); block != end(blocks); ++block) { + if (function.GetBlockDepth(*block) > + control_flow_nesting_depth_limit) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef((*block)->id())) + << "Maximum Control Flow nesting depth exceeded."; + } + } + } + } + + /// Structured control flow checks are only required for shader capabilities + if (_.HasCapability(SpvCapabilityShader)) { + if (auto error = StructuredControlFlowChecks(_, &function, back_edges)) + return error; + } + } + return SPV_SUCCESS; +} + +spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) { + SpvOp opcode = inst->opcode(); + switch (opcode) { + case SpvOpLabel: + if (auto error = _.current_function().RegisterBlock(inst->id())) + return error; + + // TODO(github:1661) This should be done in the + // ValidationState::RegisterInstruction method but because of the order of + // passes the OpLabel ends up not being part of the basic block it starts. + _.current_function().current_block()->set_label(inst); + break; + case SpvOpLoopMerge: { + uint32_t merge_block = inst->GetOperandAs(0); + uint32_t continue_block = inst->GetOperandAs(1); + CFG_ASSERT(MergeBlockAssert, merge_block); + + if (auto error = _.current_function().RegisterLoopMerge(merge_block, + continue_block)) + return error; + } break; + case SpvOpSelectionMerge: { + uint32_t merge_block = inst->GetOperandAs(0); + CFG_ASSERT(MergeBlockAssert, merge_block); + + if (auto error = _.current_function().RegisterSelectionMerge(merge_block)) + return error; + } break; + case SpvOpBranch: { + uint32_t target = inst->GetOperandAs(0); + CFG_ASSERT(FirstBlockAssert, target); + + _.current_function().RegisterBlockEnd({target}, opcode); + } break; + case SpvOpBranchConditional: { + uint32_t tlabel = inst->GetOperandAs(1); + uint32_t flabel = inst->GetOperandAs(2); + CFG_ASSERT(FirstBlockAssert, tlabel); + CFG_ASSERT(FirstBlockAssert, flabel); + + _.current_function().RegisterBlockEnd({tlabel, flabel}, opcode); + } break; + + case SpvOpSwitch: { + std::vector cases; + for (size_t i = 1; i < inst->operands().size(); i += 2) { + uint32_t target = inst->GetOperandAs(i); + CFG_ASSERT(FirstBlockAssert, target); + cases.push_back(target); + } + _.current_function().RegisterBlockEnd({cases}, opcode); + } break; + case SpvOpReturn: { + const uint32_t return_type = _.current_function().GetResultTypeId(); + const Instruction* return_type_inst = _.FindDef(return_type); + assert(return_type_inst); + if (return_type_inst->opcode() != SpvOpTypeVoid) + return _.diag(SPV_ERROR_INVALID_CFG, inst) + << "OpReturn can only be called from a function with void " + << "return type."; + } + // Fallthrough. + case SpvOpKill: + case SpvOpReturnValue: + case SpvOpUnreachable: + _.current_function().RegisterBlockEnd(std::vector(), opcode); + if (opcode == SpvOpKill) { + _.current_function().RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpKill requires Fragment execution model"); + } + break; + default: + break; + } + return SPV_SUCCESS; +} + +spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpPhi: + if (auto error = ValidatePhi(_, inst)) return error; + break; + case SpvOpBranchConditional: + if (auto error = ValidateBranchConditional(_, inst)) return error; + break; + case SpvOpReturnValue: + if (auto error = ValidateReturnValue(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_composites.cpp b/3rdparty/spirv-tools/source/val/validate_composites.cpp new file mode 100644 index 000000000..6be60261e --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_composites.cpp @@ -0,0 +1,512 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of composite SPIR-V instructions. + +#include "source/val/validate.h" + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Returns the type of the value accessed by OpCompositeExtract or +// OpCompositeInsert instruction. The function traverses the hierarchy of +// nested data structures (structs, arrays, vectors, matrices) as directed by +// the sequence of indices in the instruction. May return error if traversal +// fails (encountered non-composite, out of bounds, nesting too deep). +// Returns the type of Composite operand if the instruction has no indices. +spv_result_t GetExtractInsertValueType(ValidationState_t& _, + const Instruction* inst, + uint32_t* member_type) { + const SpvOp opcode = inst->opcode(); + assert(opcode == SpvOpCompositeExtract || opcode == SpvOpCompositeInsert); + uint32_t word_index = opcode == SpvOpCompositeExtract ? 4 : 5; + const uint32_t num_words = static_cast(inst->words().size()); + const uint32_t composite_id_index = word_index - 1; + + const uint32_t num_indices = num_words - word_index; + const uint32_t kCompositeExtractInsertMaxNumIndices = 255; + if (num_indices > kCompositeExtractInsertMaxNumIndices) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The number of indexes in Op" << spvOpcodeString(opcode) + << " may not exceed " << kCompositeExtractInsertMaxNumIndices + << ". Found " << num_indices << " indexes."; + } + + *member_type = _.GetTypeId(inst->word(composite_id_index)); + if (*member_type == 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Composite to be an object of composite type"; + } + + for (; word_index < num_words; ++word_index) { + const uint32_t component_index = inst->word(word_index); + const Instruction* const type_inst = _.FindDef(*member_type); + assert(type_inst); + switch (type_inst->opcode()) { + case SpvOpTypeVector: { + *member_type = type_inst->word(2); + const uint32_t vector_size = type_inst->word(3); + if (component_index >= vector_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Vector access is out of bounds, vector size is " + << vector_size << ", but access index is " << component_index; + } + break; + } + case SpvOpTypeMatrix: { + *member_type = type_inst->word(2); + const uint32_t num_cols = type_inst->word(3); + if (component_index >= num_cols) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Matrix access is out of bounds, matrix has " << num_cols + << " columns, but access index is " << component_index; + } + break; + } + case SpvOpTypeArray: { + uint64_t array_size = 0; + auto size = _.FindDef(type_inst->word(3)); + *member_type = type_inst->word(2); + if (spvOpcodeIsSpecConstant(size->opcode())) { + // Cannot verify against the size of this array. + break; + } + + if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { + assert(0 && "Array type definition is corrupt"); + } + if (component_index >= array_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Array access is out of bounds, array size is " + << array_size << ", but access index is " << component_index; + } + break; + } + case SpvOpTypeRuntimeArray: { + *member_type = type_inst->word(2); + // Array size is unknown. + break; + } + case SpvOpTypeStruct: { + const size_t num_struct_members = type_inst->words().size() - 2; + if (component_index >= num_struct_members) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Index is out of bounds, can not find index " + << component_index << " in the structure '" + << type_inst->id() << "'. This structure has " + << num_struct_members << " members. Largest valid index is " + << num_struct_members - 1 << "."; + } + *member_type = type_inst->word(component_index + 2); + break; + } + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Reached non-composite type while indexes still remain to " + "be traversed."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const SpvOp result_opcode = _.GetIdOpcode(result_type); + if (!spvOpcodeIsScalarType(result_opcode)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a scalar type"; + } + + const uint32_t vector_type = _.GetOperandTypeId(inst, 2); + const SpvOp vector_opcode = _.GetIdOpcode(vector_type); + if (vector_opcode != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Vector type to be OpTypeVector"; + } + + if (_.GetComponentType(vector_type) != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Vector component type to be equal to Result Type"; + } + + const uint32_t index_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarType(index_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Index to be int scalar"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const SpvOp result_opcode = _.GetIdOpcode(result_type); + if (result_opcode != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeVector"; + } + + const uint32_t vector_type = _.GetOperandTypeId(inst, 2); + if (vector_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Vector type to be equal to Result Type"; + } + + const uint32_t component_type = _.GetOperandTypeId(inst, 3); + if (_.GetComponentType(result_type) != component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Component type to be equal to Result Type " + << "component type"; + } + + const uint32_t index_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarType(index_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Index to be int scalar"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCompositeConstruct(ValidationState_t& _, + const Instruction* inst) { + const uint32_t num_operands = static_cast(inst->operands().size()); + const uint32_t result_type = inst->type_id(); + const SpvOp result_opcode = _.GetIdOpcode(result_type); + switch (result_opcode) { + case SpvOpTypeVector: { + const uint32_t num_result_components = _.GetDimension(result_type); + const uint32_t result_component_type = _.GetComponentType(result_type); + uint32_t given_component_count = 0; + + if (num_operands <= 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected number of constituents to be at least 2"; + } + + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (operand_type == result_component_type) { + ++given_component_count; + } else { + if (_.GetIdOpcode(operand_type) != SpvOpTypeVector || + _.GetComponentType(operand_type) != result_component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituents to be scalars or vectors of" + << " the same type as Result Type components"; + } + + given_component_count += _.GetDimension(operand_type); + } + } + + if (num_result_components != given_component_count) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of given components to be equal " + << "to the size of Result Type vector"; + } + + break; + } + case SpvOpTypeMatrix: { + uint32_t result_num_rows = 0; + uint32_t result_num_cols = 0; + uint32_t result_col_type = 0; + uint32_t result_component_type = 0; + if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols, + &result_col_type, &result_component_type)) { + assert(0); + } + + if (result_num_cols + 2 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of Constituents to be equal " + << "to the number of columns of Result Type matrix"; + } + + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (operand_type != result_col_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the column " + << "type Result Type matrix"; + } + } + + break; + } + case SpvOpTypeArray: { + const Instruction* const array_inst = _.FindDef(result_type); + assert(array_inst); + assert(array_inst->opcode() == SpvOpTypeArray); + + auto size = _.FindDef(array_inst->word(3)); + if (spvOpcodeIsSpecConstant(size->opcode())) { + // Cannot verify against the size of this array. + break; + } + + uint64_t array_size = 0; + if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) { + assert(0 && "Array type definition is corrupt"); + } + + if (array_size + 2 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of Constituents to be equal " + << "to the number of elements of Result Type array"; + } + + const uint32_t result_component_type = array_inst->word(2); + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (operand_type != result_component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the column " + << "type Result Type array"; + } + } + + break; + } + case SpvOpTypeStruct: { + const Instruction* const struct_inst = _.FindDef(result_type); + assert(struct_inst); + assert(struct_inst->opcode() == SpvOpTypeStruct); + + if (struct_inst->operands().size() + 1 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of Constituents to be equal " + << "to the number of members of Result Type struct"; + } + + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + const uint32_t member_type = struct_inst->word(operand_index); + if (operand_type != member_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the " + << "corresponding member type of Result Type struct"; + } + } + + break; + } + default: { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a composite type"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCompositeExtract(ValidationState_t& _, + const Instruction* inst) { + uint32_t member_type = 0; + if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) { + return error; + } + + const uint32_t result_type = inst->type_id(); + if (result_type != member_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result type (Op" << spvOpcodeString(_.GetIdOpcode(result_type)) + << ") does not match the type that results from indexing into " + "the composite (Op" + << spvOpcodeString(_.GetIdOpcode(member_type)) << ")."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCompositeInsert(ValidationState_t& _, + const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t object_type = _.GetOperandTypeId(inst, 2); + const uint32_t composite_type = _.GetOperandTypeId(inst, 3); + const uint32_t result_type = inst->type_id(); + if (result_type != composite_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Result Type must be the same as Composite type in Op" + << spvOpcodeString(opcode) << " yielding Result Id " << result_type + << "."; + } + + uint32_t member_type = 0; + if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) { + return error; + } + + if (object_type != member_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Object type (Op" + << spvOpcodeString(_.GetIdOpcode(object_type)) + << ") does not match the type that results from indexing into the " + "Composite (Op" + << spvOpcodeString(_.GetIdOpcode(member_type)) << ")."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCopyObject(ValidationState_t& _, const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const uint32_t operand_type = _.GetOperandTypeId(inst, 2); + if (operand_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type and Operand type to be the same"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) { + uint32_t result_num_rows = 0; + uint32_t result_num_cols = 0; + uint32_t result_col_type = 0; + uint32_t result_component_type = 0; + const uint32_t result_type = inst->type_id(); + if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols, + &result_col_type, &result_component_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a matrix type"; + } + + const uint32_t matrix_type = _.GetOperandTypeId(inst, 2); + uint32_t matrix_num_rows = 0; + uint32_t matrix_num_cols = 0; + uint32_t matrix_col_type = 0; + uint32_t matrix_component_type = 0; + if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols, + &matrix_col_type, &matrix_component_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Matrix to be of type OpTypeMatrix"; + } + + if (result_component_type != matrix_component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected component types of Matrix and Result Type to be " + << "identical"; + } + + if (result_num_rows != matrix_num_cols || + result_num_cols != matrix_num_rows) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected number of columns and the column size of Matrix " + << "to be the reverse of those of Result Type"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateVectorShuffle(ValidationState_t& _, + const Instruction* inst) { + auto resultType = _.FindDef(inst->type_id()); + if (!resultType || resultType->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Result Type of OpVectorShuffle must be" + << " OpTypeVector. Found Op" + << spvOpcodeString(static_cast(resultType->opcode())) << "."; + } + + // The number of components in Result Type must be the same as the number of + // Component operands. + auto componentCount = inst->operands().size() - 4; + auto resultVectorDimension = resultType->GetOperandAs(2); + if (componentCount != resultVectorDimension) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVectorShuffle component literals count does not match " + "Result Type '" + << _.getIdName(resultType->id()) << "'s vector component count."; + } + + // Vector 1 and Vector 2 must both have vector types, with the same Component + // Type as Result Type. + auto vector1Object = _.FindDef(inst->GetOperandAs(2)); + auto vector1Type = _.FindDef(vector1Object->type_id()); + auto vector2Object = _.FindDef(inst->GetOperandAs(3)); + auto vector2Type = _.FindDef(vector2Object->type_id()); + if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The type of Vector 1 must be OpTypeVector."; + } + if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The type of Vector 2 must be OpTypeVector."; + } + + auto resultComponentType = resultType->GetOperandAs(1); + if (vector1Type->GetOperandAs(1) != resultComponentType) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Component Type of Vector 1 must be the same as ResultType."; + } + if (vector2Type->GetOperandAs(1) != resultComponentType) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Component Type of Vector 2 must be the same as ResultType."; + } + + // All Component literals must either be FFFFFFFF or in [0, N - 1]. + auto vector1ComponentCount = vector1Type->GetOperandAs(2); + auto vector2ComponentCount = vector2Type->GetOperandAs(2); + auto N = vector1ComponentCount + vector2ComponentCount; + auto firstLiteralIndex = 4; + for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) { + auto literal = inst->GetOperandAs(i); + if (literal != 0xFFFFFFFF && literal >= N) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Component index " << literal << " is out of bounds for " + << "combined (Vector1 + Vector2) size of " << N << "."; + } + } + + return SPV_SUCCESS; +} + +} // anonymous namespace + +// Validates correctness of composite instructions. +spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpVectorExtractDynamic: + return ValidateVectorExtractDynamic(_, inst); + case SpvOpVectorInsertDynamic: + return ValidateVectorInsertDyanmic(_, inst); + case SpvOpVectorShuffle: + return ValidateVectorShuffle(_, inst); + case SpvOpCompositeConstruct: + return ValidateCompositeConstruct(_, inst); + case SpvOpCompositeExtract: + return ValidateCompositeExtract(_, inst); + case SpvOpCompositeInsert: + return ValidateCompositeInsert(_, inst); + case SpvOpCopyObject: + return ValidateCopyObject(_, inst); + case SpvOpTranspose: + return ValidateTranspose(_, inst); + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_constants.cpp b/3rdparty/spirv-tools/source/val/validate_constants.cpp new file mode 100644 index 000000000..5dbe6c6df --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_constants.cpp @@ -0,0 +1,346 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateConstantBool(ValidationState_t& _, + const Instruction* inst) { + auto type = _.FindDef(inst->type_id()); + if (!type || type->opcode() != SpvOpTypeBool) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Op" << spvOpcodeString(inst->opcode()) << " Result Type '" + << _.getIdName(inst->type_id()) << "' is not a boolean type."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateConstantComposite(ValidationState_t& _, + const Instruction* inst) { + std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode()); + + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Result Type '" + << _.getIdName(inst->type_id()) << "' is not a composite type."; + } + + const auto constituent_count = inst->words().size() - 3; + switch (result_type->opcode()) { + case SpvOpTypeVector: { + const auto component_count = result_type->GetOperandAs(2); + if (component_count != constituent_count) { + // TODO: Output ID's on diagnostic + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name + << " Constituent count does not match " + "Result Type '" + << _.getIdName(result_type->id()) + << "'s vector component count."; + } + const auto component_type = + _.FindDef(result_type->GetOperandAs(1)); + if (!component_type) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Component type is not defined."; + } + for (size_t constituent_index = 2; + constituent_index < inst->operands().size(); constituent_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_result_type = _.FindDef(constituent->type_id()); + if (!constituent_result_type || + component_type->opcode() != constituent_result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "'s type does not match Result Type '" + << _.getIdName(result_type->id()) << "'s vector element type."; + } + } + } break; + case SpvOpTypeMatrix: { + const auto column_count = result_type->GetOperandAs(2); + if (column_count != constituent_count) { + // TODO: Output ID's on diagnostic + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name + << " Constituent count does not match " + "Result Type '" + << _.getIdName(result_type->id()) << "'s matrix column count."; + } + + const auto column_type = _.FindDef(result_type->words()[2]); + if (!column_type) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Column type is not defined."; + } + const auto component_count = column_type->GetOperandAs(2); + const auto component_type = + _.FindDef(column_type->GetOperandAs(1)); + if (!component_type) { + return _.diag(SPV_ERROR_INVALID_ID, column_type) + << "Component type is not defined."; + } + + for (size_t constituent_index = 2; + constituent_index < inst->operands().size(); constituent_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !(SpvOpConstantComposite == constituent->opcode() || + SpvOpSpecConstantComposite == constituent->opcode() || + SpvOpUndef == constituent->opcode())) { + // The message says "... or undef" because the spec does not say + // undef is a constant. + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant composite or undef."; + } + const auto vector = _.FindDef(constituent->type_id()); + if (!vector) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + if (column_type->opcode() != vector->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' type does not match Result Type '" + << _.getIdName(result_type->id()) << "'s matrix column type."; + } + const auto vector_component_type = + _.FindDef(vector->GetOperandAs(1)); + if (component_type->id() != vector_component_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' component type does not match Result Type '" + << _.getIdName(result_type->id()) + << "'s matrix column component type."; + } + if (component_count != vector->words()[3]) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' vector component count does not match Result Type '" + << _.getIdName(result_type->id()) + << "'s vector component count."; + } + } + } break; + case SpvOpTypeArray: { + auto element_type = _.FindDef(result_type->GetOperandAs(1)); + if (!element_type) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Element type is not defined."; + } + const auto length = _.FindDef(result_type->GetOperandAs(2)); + if (!length) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Length is not defined."; + } + bool is_int32; + bool is_const; + uint32_t value; + std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); + if (is_int32 && is_const && !spvOpcodeIsSpecConstant(length->opcode()) && + value != constituent_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name + << " Constituent count does not match " + "Result Type '" + << _.getIdName(result_type->id()) << "'s array length."; + } + for (size_t constituent_index = 2; + constituent_index < inst->operands().size(); constituent_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_type = _.FindDef(constituent->type_id()); + if (!constituent_type) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + if (element_type->id() != constituent_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "'s type does not match Result Type '" + << _.getIdName(result_type->id()) << "'s array element type."; + } + } + } break; + case SpvOpTypeStruct: { + const auto member_count = result_type->words().size() - 2; + if (member_count != constituent_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(inst->type_id()) + << "' count does not match Result Type '" + << _.getIdName(result_type->id()) << "'s struct member count."; + } + for (uint32_t constituent_index = 2, member_index = 1; + constituent_index < inst->operands().size(); + constituent_index++, member_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_type = _.FindDef(constituent->type_id()); + if (!constituent_type) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + + const auto member_type_id = + result_type->GetOperandAs(member_index); + const auto member_type = _.FindDef(member_type_id); + if (!member_type || member_type->id() != constituent_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' type does not match the Result Type '" + << _.getIdName(result_type->id()) << "'s member type."; + } + } + } break; + default: + break; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateConstantSampler(ValidationState_t& _, + const Instruction* inst) { + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type || result_type->opcode() != SpvOpTypeSampler) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "OpConstantSampler Result Type '" + << _.getIdName(inst->type_id()) << "' is not a sampler type."; + } + + return SPV_SUCCESS; +} + +// True if instruction defines a type that can have a null value, as defined by +// the SPIR-V spec. Tracks composite-type components through module to check +// nullability transitively. +bool IsTypeNullable(const std::vector& instruction, + const ValidationState_t& _) { + uint16_t opcode; + uint16_t word_count; + spvOpcodeSplit(instruction[0], &word_count, &opcode); + switch (static_cast(opcode)) { + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypePointer: + case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + return true; + case SpvOpTypeArray: + case SpvOpTypeMatrix: + case SpvOpTypeVector: { + auto base_type = _.FindDef(instruction[2]); + return base_type && IsTypeNullable(base_type->words(), _); + } + case SpvOpTypeStruct: { + for (size_t elementIndex = 2; elementIndex < instruction.size(); + ++elementIndex) { + auto element = _.FindDef(instruction[elementIndex]); + if (!element || !IsTypeNullable(element->words(), _)) return false; + } + return true; + } + default: + return false; + } +} + +spv_result_t ValidateConstantNull(ValidationState_t& _, + const Instruction* inst) { + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type || !IsTypeNullable(result_type->words(), _)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpConstantNull Result Type '" + << _.getIdName(inst->type_id()) << "' cannot have a null value."; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpConstantTrue: + case SpvOpConstantFalse: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + if (auto error = ValidateConstantBool(_, inst)) return error; + break; + case SpvOpConstantComposite: + case SpvOpSpecConstantComposite: + if (auto error = ValidateConstantComposite(_, inst)) return error; + break; + case SpvOpConstantSampler: + if (auto error = ValidateConstantSampler(_, inst)) return error; + break; + case SpvOpConstantNull: + if (auto error = ValidateConstantNull(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_conversion.cpp b/3rdparty/spirv-tools/source/val/validate_conversion.cpp similarity index 81% rename from 3rdparty/spirv-tools/source/validate_conversion.cpp rename to 3rdparty/spirv-tools/source/val/validate_conversion.cpp index 75ff6a2e1..9c6f68c6f 100644 --- a/3rdparty/spirv-tools/source/validate_conversion.cpp +++ b/3rdparty/spirv-tools/source/val/validate_conversion.cpp @@ -14,38 +14,38 @@ // Validates correctness of conversion instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" -namespace libspirv { +namespace spvtools { +namespace val { // Validates correctness of conversion instructions. -spv_result_t ConversionPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpConvertFToU: { if (!_.IsUnsignedIntScalarType(result_type) && !_.IsUnsignedIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); @@ -54,19 +54,19 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpConvertFToS: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); @@ -77,19 +77,19 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpConvertUToF: { if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); @@ -99,24 +99,24 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpUConvert: { if (!_.IsUnsignedIntScalarType(result_type) && !_.IsUnsignedIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have different bit width from Result " "Type: " << spvOpcodeString(opcode); @@ -125,24 +125,24 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpSConvert: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have different bit width from Result " "Type: " << spvOpcodeString(opcode); @@ -152,24 +152,24 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpFConvert: { if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have different bit width from Result " "Type: " << spvOpcodeString(opcode); @@ -180,13 +180,13 @@ spv_result_t ConversionPass(ValidationState_t& _, if ((!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type)) || _.GetBitWidth(result_type) != 32) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected 32-bit float scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (input_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input type to be equal to Result Type: " << spvOpcodeString(opcode); break; @@ -194,13 +194,13 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpConvertPtrToU: { if (!_.IsUnsignedIntScalarType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!_.IsPointerType(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer: " << spvOpcodeString(opcode); break; } @@ -208,19 +208,19 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpSatConvertSToU: case SpvOpSatConvertUToS: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector as input: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same dimension as Result Type: " << spvOpcodeString(opcode); break; @@ -228,13 +228,13 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpConvertUToPtr: { if (!_.IsPointerType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be a pointer: " << spvOpcodeString(opcode); const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!_.IsIntScalarType(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar as input: " << spvOpcodeString(opcode); break; } @@ -244,12 +244,12 @@ spv_result_t ConversionPass(ValidationState_t& _, uint32_t result_data_type = 0; if (!_.GetPointerTypeInfo(result_type, &result_data_type, &result_storage_class)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be a pointer: " << spvOpcodeString(opcode); if (result_storage_class != SpvStorageClassGeneric) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to have storage class Generic: " << spvOpcodeString(opcode); @@ -258,18 +258,18 @@ spv_result_t ConversionPass(ValidationState_t& _, uint32_t input_data_type = 0; if (!_.GetPointerTypeInfo(input_type, &input_data_type, &input_storage_class)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer: " << spvOpcodeString(opcode); if (input_storage_class != SpvStorageClassWorkgroup && input_storage_class != SpvStorageClassCrossWorkgroup && input_storage_class != SpvStorageClassFunction) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have storage class Workgroup, " << "CrossWorkgroup or Function: " << spvOpcodeString(opcode); if (result_data_type != input_data_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input and Result Type to point to the same type: " << spvOpcodeString(opcode); break; @@ -280,14 +280,14 @@ spv_result_t ConversionPass(ValidationState_t& _, uint32_t result_data_type = 0; if (!_.GetPointerTypeInfo(result_type, &result_data_type, &result_storage_class)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be a pointer: " << spvOpcodeString(opcode); if (result_storage_class != SpvStorageClassWorkgroup && result_storage_class != SpvStorageClassCrossWorkgroup && result_storage_class != SpvStorageClassFunction) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to have storage class Workgroup, " << "CrossWorkgroup or Function: " << spvOpcodeString(opcode); @@ -296,16 +296,16 @@ spv_result_t ConversionPass(ValidationState_t& _, uint32_t input_data_type = 0; if (!_.GetPointerTypeInfo(input_type, &input_data_type, &input_storage_class)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer: " << spvOpcodeString(opcode); if (input_storage_class != SpvStorageClassGeneric) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have storage class Generic: " << spvOpcodeString(opcode); if (result_data_type != input_data_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input and Result Type to point to the same type: " << spvOpcodeString(opcode); break; @@ -316,13 +316,13 @@ spv_result_t ConversionPass(ValidationState_t& _, uint32_t result_data_type = 0; if (!_.GetPointerTypeInfo(result_type, &result_data_type, &result_storage_class)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be a pointer: " << spvOpcodeString(opcode); - const uint32_t target_storage_class = inst->words[4]; + const uint32_t target_storage_class = inst->word(4); if (result_storage_class != target_storage_class) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be of target storage class: " << spvOpcodeString(opcode); @@ -331,23 +331,23 @@ spv_result_t ConversionPass(ValidationState_t& _, uint32_t input_data_type = 0; if (!_.GetPointerTypeInfo(input_type, &input_data_type, &input_storage_class)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer: " << spvOpcodeString(opcode); if (input_storage_class != SpvStorageClassGeneric) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have storage class Generic: " << spvOpcodeString(opcode); if (result_data_type != input_data_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input and Result Type to point to the same type: " << spvOpcodeString(opcode); if (target_storage_class != SpvStorageClassWorkgroup && target_storage_class != SpvStorageClassCrossWorkgroup && target_storage_class != SpvStorageClassFunction) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected target storage class to be Workgroup, " << "CrossWorkgroup or Function: " << spvOpcodeString(opcode); break; @@ -356,7 +356,7 @@ spv_result_t ConversionPass(ValidationState_t& _, case SpvOpBitcast: { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have a type: " << spvOpcodeString(opcode); const bool result_is_pointer = _.IsPointerType(result_type); @@ -368,24 +368,24 @@ spv_result_t ConversionPass(ValidationState_t& _, !_.IsIntVectorType(result_type) && !_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be a pointer or int or float vector " << "or scalar type: " << spvOpcodeString(opcode); if (!input_is_pointer && !input_is_int_scalar && !_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer or int or float vector " << "or scalar: " << spvOpcodeString(opcode); if (result_is_pointer && !input_is_pointer && !input_is_int_scalar) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer or int scalar if Result Type " << "is pointer: " << spvOpcodeString(opcode); if (input_is_pointer && !result_is_pointer && !result_is_int_scalar) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Pointer can only be converted to another pointer or int " << "scalar: " << spvOpcodeString(opcode); @@ -395,7 +395,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const uint32_t input_size = _.GetBitWidth(input_type) * _.GetDimension(input_type); if (result_size != input_size) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to have the same total bit width as " << "Result Type: " << spvOpcodeString(opcode); } @@ -409,4 +409,5 @@ spv_result_t ConversionPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_datarules.cpp b/3rdparty/spirv-tools/source/val/validate_datarules.cpp similarity index 69% rename from 3rdparty/spirv-tools/source/validate_datarules.cpp rename to 3rdparty/spirv-tools/source/val/validate_datarules.cpp index 7d51ecaba..129b6bbf9 100644 --- a/3rdparty/spirv-tools/source/validate_datarules.cpp +++ b/3rdparty/spirv-tools/source/val/validate_datarules.cpp @@ -14,31 +14,29 @@ // Ensures Data Rules are followed according to the specifications. -#include "validate.h" +#include "source/val/validate.h" #include #include #include -#include "diagnostic.h" -#include "opcode.h" -#include "operand.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -using libspirv::CapabilitySet; -using libspirv::DiagnosticStream; -using libspirv::ValidationState_t; +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { // Validates that the number of components in the vector is valid. // Vector types can only be parameterized as having 2, 3, or 4 components. // If the Vector16 capability is added, 8 and 16 components are also allowed. spv_result_t ValidateVecNumComponents(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { + const Instruction* inst) { // Operand 2 specifies the number of components in the vector. - const uint32_t num_components = inst->words[inst->operands[2].offset]; + auto num_components = inst->GetOperandAs(2); if (num_components == 2 || num_components == 3 || num_components == 4) { return SPV_SUCCESS; } @@ -46,14 +44,14 @@ spv_result_t ValidateVecNumComponents(ValidationState_t& _, if (_.HasCapability(SpvCapabilityVector16)) { return SPV_SUCCESS; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Having " << num_components << " components for " - << spvOpcodeString(static_cast(inst->opcode)) + << spvOpcodeString(inst->opcode()) << " requires the Vector16 capability"; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Illegal number of components (" << num_components << ") for " - << spvOpcodeString(static_cast(inst->opcode)); + << spvOpcodeString(inst->opcode()); } // Validates that the number of bits specifed for a float type is valid. @@ -61,10 +59,9 @@ spv_result_t ValidateVecNumComponents(ValidationState_t& _, // Float16 capability allows using a 16-bit OpTypeFloat. // Float16Buffer capability allows creation of a 16-bit OpTypeFloat. // Float64 capability allows using a 64-bit OpTypeFloat. -spv_result_t ValidateFloatSize(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { +spv_result_t ValidateFloatSize(ValidationState_t& _, const Instruction* inst) { // Operand 1 is the number of bits for this float - const uint32_t num_bits = inst->words[inst->operands[1].offset]; + auto num_bits = inst->GetOperandAs(1); if (num_bits == 32) { return SPV_SUCCESS; } @@ -72,7 +69,7 @@ spv_result_t ValidateFloatSize(ValidationState_t& _, if (_.features().declare_float16_type) { return SPV_SUCCESS; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Using a 16-bit floating point " << "type requires the Float16 or Float16Buffer capability," " or an extension that explicitly enables 16-bit floating point."; @@ -81,11 +78,11 @@ spv_result_t ValidateFloatSize(ValidationState_t& _, if (_.HasCapability(SpvCapabilityFloat64)) { return SPV_SUCCESS; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Using a 64-bit floating point " << "type requires the Float64 capability."; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat."; } @@ -93,25 +90,25 @@ spv_result_t ValidateFloatSize(ValidationState_t& _, // Scalar integer types can be parameterized only with 32-bits. // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit // integers, respectively. -spv_result_t ValidateIntSize(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { +spv_result_t ValidateIntSize(ValidationState_t& _, const Instruction* inst) { // Operand 1 is the number of bits for this integer. - const uint32_t num_bits = inst->words[inst->operands[1].offset]; + auto num_bits = inst->GetOperandAs(1); if (num_bits == 32) { return SPV_SUCCESS; } if (num_bits == 8) { - if (_.HasCapability(SpvCapabilityInt8)) { + if (_.features().declare_int8_type) { return SPV_SUCCESS; } - return _.diag(SPV_ERROR_INVALID_DATA) - << "Using an 8-bit integer type requires the Int8 capability."; + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using an 8-bit integer type requires the Int8 capability," + " or an extension that explicitly enables 8-bit integers."; } if (num_bits == 16) { if (_.features().declare_int16_type) { return SPV_SUCCESS; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Using a 16-bit integer type requires the Int16 capability," " or an extension that explicitly enables 16-bit integers."; } @@ -119,22 +116,22 @@ spv_result_t ValidateIntSize(ValidationState_t& _, if (_.HasCapability(SpvCapabilityInt64)) { return SPV_SUCCESS; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Using a 64-bit integer type requires the Int64 capability."; } - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Invalid number of bits (" << num_bits << ") used for OpTypeInt."; } // Validates that the matrix is parameterized with floating-point types. spv_result_t ValidateMatrixColumnType(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { + const Instruction* inst) { // Find the component type of matrix columns (must be vector). // Operand 1 is the of the type specified for matrix columns. - auto type_id = inst->words[inst->operands[1].offset]; + auto type_id = inst->GetOperandAs(1); auto col_type_instr = _.FindDef(type_id); if (col_type_instr->opcode() != SpvOpTypeVector) { - return _.diag(SPV_ERROR_INVALID_ID) + return _.diag(SPV_ERROR_INVALID_ID, inst) << "Columns in a matrix must be of type vector."; } @@ -144,73 +141,72 @@ spv_result_t ValidateMatrixColumnType(ValidationState_t& _, col_type_instr->words()[col_type_instr->operands()[1].offset]; auto comp_type_instruction = _.FindDef(comp_type_id); if (comp_type_instruction->opcode() != SpvOpTypeFloat) { - return _.diag(SPV_ERROR_INVALID_DATA) << "Matrix types can only be " - "parameterized with " - "floating-point types."; + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be " + "parameterized with " + "floating-point types."; } return SPV_SUCCESS; } // Validates that the matrix has 2,3, or 4 columns. spv_result_t ValidateMatrixNumCols(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { + const Instruction* inst) { // Operand 2 is the number of columns in the matrix. - const uint32_t num_cols = inst->words[inst->operands[2].offset]; + auto num_cols = inst->GetOperandAs(2); if (num_cols != 2 && num_cols != 3 && num_cols != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) << "Matrix types can only be " - "parameterized as having only 2, " - "3, or 4 columns."; + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be " + "parameterized as having " + "only 2, 3, or 4 columns."; } return SPV_SUCCESS; } // Validates that OpSpecConstant specializes to either int or float type. spv_result_t ValidateSpecConstNumerical(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { + const Instruction* inst) { // Operand 0 is the of the type that we're specializing to. - auto type_id = inst->words[inst->operands[0].offset]; + auto type_id = inst->GetOperandAs(0); auto type_instruction = _.FindDef(type_id); auto type_opcode = type_instruction->opcode(); if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) { - return _.diag(SPV_ERROR_INVALID_DATA) << "Specialization constant must be " - "an integer or floating-point " - "number."; + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant " + "must be an integer or " + "floating-point number."; } return SPV_SUCCESS; } // Validates that OpSpecConstantTrue and OpSpecConstantFalse specialize to bool. spv_result_t ValidateSpecConstBoolean(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { + const Instruction* inst) { // Find out the type that we're specializing to. - auto type_instruction = _.FindDef(inst->type_id); + auto type_instruction = _.FindDef(inst->type_id()); if (type_instruction->opcode() != SpvOpTypeBool) { - return _.diag(SPV_ERROR_INVALID_ID) << "Specialization constant must be " - "a boolean type."; + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Specialization constant must be a boolean type."; } return SPV_SUCCESS; } // Records the of the forward pointer to be used for validation. spv_result_t ValidateForwardPointer(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { + const Instruction* inst) { // Record the (which is operand 0) to ensure it's used properly. // OpTypeStruct can only include undefined pointers that are // previously declared as a ForwardPointer - return (_.RegisterForwardPointer(inst->words[inst->operands[0].offset])); + return (_.RegisterForwardPointer(inst->GetOperandAs(0))); } // Validates that any undefined component of the struct is a forward pointer. // It is valid to declare a forward pointer, and use its as one of the // components of a struct. -spv_result_t ValidateStruct(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { +spv_result_t ValidateStruct(ValidationState_t& _, const Instruction* inst) { // Struct components are operands 1, 2, etc. - for (unsigned i = 1; i < inst->num_operands; i++) { - auto type_id = inst->words[inst->operands[i].offset]; + for (unsigned i = 1; i < inst->operands().size(); i++) { + auto type_id = inst->GetOperandAs(i); auto type_instruction = _.FindDef(type_id); if (type_instruction == nullptr && !_.IsForwardPointer(type_id)) { - return _.diag(SPV_ERROR_INVALID_ID) + return _.diag(SPV_ERROR_INVALID_ID, inst) << "Forward reference operands in an OpTypeStruct must first be " "declared using OpTypeForwardPointer."; } @@ -218,15 +214,12 @@ spv_result_t ValidateStruct(ValidationState_t& _, return SPV_SUCCESS; } -} // anonymous namespace - -namespace libspirv { +} // namespace // Validates that Data Rules are followed according to the specifications. // (Data Rules subsection of 2.16.1 Universal Validation Rules) -spv_result_t DataRulesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - switch (inst->opcode) { +spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { case SpvOpTypeVector: { if (auto error = ValidateVecNumComponents(_, inst)) return error; break; @@ -270,4 +263,5 @@ spv_result_t DataRulesPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_debug.cpp b/3rdparty/spirv-tools/source/val/validate_debug.cpp new file mode 100644 index 000000000..d84ed3801 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_debug.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateMemberName(ValidationState_t& _, const Instruction* inst) { + const auto type_id = inst->GetOperandAs(0); + const auto type = _.FindDef(type_id); + if (!type || SpvOpTypeStruct != type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpMemberName Type '" << _.getIdName(type_id) + << "' is not a struct type."; + } + const auto member_id = inst->GetOperandAs(1); + const auto member_count = (uint32_t)(type->words().size() - 2); + if (member_count <= member_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpMemberName Member '" << _.getIdName(member_id) + << "' index is larger than Type '" << _.getIdName(type->id()) + << "'s member count."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateLine(ValidationState_t& _, const Instruction* inst) { + const auto file_id = inst->GetOperandAs(0); + const auto file = _.FindDef(file_id); + if (!file || SpvOpString != file->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpLine Target '" << _.getIdName(file_id) + << "' is not an OpString."; + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t DebugPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpMemberName: + if (auto error = ValidateMemberName(_, inst)) return error; + break; + case SpvOpLine: + if (auto error = ValidateLine(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_decorations.cpp b/3rdparty/spirv-tools/source/val/validate_decorations.cpp new file mode 100644 index 000000000..ed312f73f --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_decorations.cpp @@ -0,0 +1,856 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include +#include +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/spirv_validator_options.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Distinguish between row and column major matrix layouts. +enum MatrixLayout { kRowMajor, kColumnMajor }; + +// A functor for hashing a pair of integers. +struct PairHash { + std::size_t operator()(const std::pair pair) const { + const uint32_t a = pair.first; + const uint32_t b = pair.second; + const uint32_t rotated_b = (b >> 2) | ((b & 3) << 30); + return a ^ rotated_b; + } +}; + +// Struct member layout attributes that are inherited through arrays. +struct LayoutConstraints { + explicit LayoutConstraints( + MatrixLayout the_majorness = MatrixLayout::kColumnMajor, + uint32_t stride = 0) + : majorness(the_majorness), matrix_stride(stride) {} + MatrixLayout majorness; + uint32_t matrix_stride; +}; + +// A type for mapping (struct id, member id) to layout constraints. +using MemberConstraints = std::unordered_map, + LayoutConstraints, PairHash>; + +// Returns the array stride of the given array type. +uint32_t GetArrayStride(uint32_t array_id, ValidationState_t& vstate) { + for (auto& decoration : vstate.id_decorations(array_id)) { + if (SpvDecorationArrayStride == decoration.dec_type()) { + return decoration.params()[0]; + } + } + return 0; +} + +// Returns true if the given variable has a BuiltIn decoration. +bool isBuiltInVar(uint32_t var_id, ValidationState_t& vstate) { + const auto& decorations = vstate.id_decorations(var_id); + return std::any_of( + decorations.begin(), decorations.end(), + [](const Decoration& d) { return SpvDecorationBuiltIn == d.dec_type(); }); +} + +// Returns true if the given structure type has any members with BuiltIn +// decoration. +bool isBuiltInStruct(uint32_t struct_id, ValidationState_t& vstate) { + const auto& decorations = vstate.id_decorations(struct_id); + return std::any_of( + decorations.begin(), decorations.end(), [](const Decoration& d) { + return SpvDecorationBuiltIn == d.dec_type() && + Decoration::kInvalidMember != d.struct_member_index(); + }); +} + +// Returns true if the given ID has the Import LinkageAttributes decoration. +bool hasImportLinkageAttribute(uint32_t id, ValidationState_t& vstate) { + const auto& decorations = vstate.id_decorations(id); + return std::any_of(decorations.begin(), decorations.end(), + [](const Decoration& d) { + return SpvDecorationLinkageAttributes == d.dec_type() && + d.params().size() >= 2u && + d.params().back() == SpvLinkageTypeImport; + }); +} + +// Returns a vector of all members of a structure. +std::vector getStructMembers(uint32_t struct_id, + ValidationState_t& vstate) { + const auto inst = vstate.FindDef(struct_id); + return std::vector(inst->words().begin() + 2, inst->words().end()); +} + +// Returns a vector of all members of a structure that have specific type. +std::vector getStructMembers(uint32_t struct_id, SpvOp type, + ValidationState_t& vstate) { + std::vector members; + for (auto id : getStructMembers(struct_id, vstate)) { + if (type == vstate.FindDef(id)->opcode()) { + members.push_back(id); + } + } + return members; +} + +// Returns whether the given structure is missing Offset decoration for any +// member. Handles also nested structures. +bool isMissingOffsetInStruct(uint32_t struct_id, ValidationState_t& vstate) { + std::vector hasOffset(getStructMembers(struct_id, vstate).size(), + false); + // Check offsets of member decorations + for (auto& decoration : vstate.id_decorations(struct_id)) { + if (SpvDecorationOffset == decoration.dec_type() && + Decoration::kInvalidMember != decoration.struct_member_index()) { + hasOffset[decoration.struct_member_index()] = true; + } + } + // Check also nested structures + bool nestedStructsMissingOffset = false; + for (auto id : getStructMembers(struct_id, SpvOpTypeStruct, vstate)) { + if (isMissingOffsetInStruct(id, vstate)) { + nestedStructsMissingOffset = true; + break; + } + } + return nestedStructsMissingOffset || + !std::all_of(hasOffset.begin(), hasOffset.end(), + [](const bool b) { return b; }); +} + +// Rounds x up to the next alignment. Assumes alignment is a power of two. +uint32_t align(uint32_t x, uint32_t alignment) { + return (x + alignment - 1) & ~(alignment - 1); +} + +// Returns base alignment of struct member. If |roundUp| is true, also +// ensure that structs and arrays are aligned at least to a multiple of 16 +// bytes. +uint32_t getBaseAlignment(uint32_t member_id, bool roundUp, + const LayoutConstraints& inherited, + MemberConstraints& constraints, + ValidationState_t& vstate) { + const auto inst = vstate.FindDef(member_id); + const auto& words = inst->words(); + uint32_t baseAlignment = 0; + switch (inst->opcode()) { + case SpvOpTypeInt: + case SpvOpTypeFloat: + baseAlignment = words[2] / 8; + break; + case SpvOpTypeVector: { + const auto componentId = words[2]; + const auto numComponents = words[3]; + const auto componentAlignment = getBaseAlignment( + componentId, roundUp, inherited, constraints, vstate); + baseAlignment = + componentAlignment * (numComponents == 3 ? 4 : numComponents); + break; + } + case SpvOpTypeMatrix: { + const auto column_type = words[2]; + if (inherited.majorness == kColumnMajor) { + baseAlignment = getBaseAlignment(column_type, roundUp, inherited, + constraints, vstate); + } else { + // A row-major matrix of C columns has a base alignment equal to the + // base alignment of a vector of C matrix components. + const auto num_columns = words[3]; + const auto component_inst = vstate.FindDef(column_type); + const auto component_id = component_inst->words()[2]; + const auto componentAlignment = getBaseAlignment( + component_id, roundUp, inherited, constraints, vstate); + baseAlignment = + componentAlignment * (num_columns == 3 ? 4 : num_columns); + } + } break; + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + baseAlignment = + getBaseAlignment(words[2], roundUp, inherited, constraints, vstate); + if (roundUp) baseAlignment = align(baseAlignment, 16u); + break; + case SpvOpTypeStruct: { + const auto members = getStructMembers(member_id, vstate); + for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size()); + memberIdx < numMembers; ++memberIdx) { + const auto id = members[memberIdx]; + const auto& constraint = + constraints[std::make_pair(member_id, memberIdx)]; + baseAlignment = std::max( + baseAlignment, + getBaseAlignment(id, roundUp, constraint, constraints, vstate)); + } + if (roundUp) baseAlignment = align(baseAlignment, 16u); + break; + } + default: + assert(0); + break; + } + + return baseAlignment; +} + +// Returns size of a struct member. Doesn't include padding at the end of struct +// or array. Assumes that in the struct case, all members have offsets. +uint32_t getSize(uint32_t member_id, bool roundUp, + const LayoutConstraints& inherited, + MemberConstraints& constraints, ValidationState_t& vstate) { + const auto inst = vstate.FindDef(member_id); + const auto& words = inst->words(); + switch (inst->opcode()) { + case SpvOpTypeInt: + case SpvOpTypeFloat: + return getBaseAlignment(member_id, roundUp, inherited, constraints, + vstate); + case SpvOpTypeVector: { + const auto componentId = words[2]; + const auto numComponents = words[3]; + const auto componentSize = + getSize(componentId, roundUp, inherited, constraints, vstate); + const auto size = componentSize * numComponents; + return size; + } + case SpvOpTypeArray: { + const auto sizeInst = vstate.FindDef(words[3]); + if (spvOpcodeIsSpecConstant(sizeInst->opcode())) return 0; + assert(SpvOpConstant == sizeInst->opcode()); + const uint32_t num_elem = sizeInst->words()[3]; + const uint32_t elem_type = words[2]; + const uint32_t elem_size = + getSize(elem_type, roundUp, inherited, constraints, vstate); + // Account for gaps due to alignments in the first N-1 elements, + // then add the size of the last element. + const auto size = + (num_elem - 1) * GetArrayStride(member_id, vstate) + elem_size; + return size; + } + case SpvOpTypeRuntimeArray: + return 0; + case SpvOpTypeMatrix: { + const auto num_columns = words[3]; + if (inherited.majorness == kColumnMajor) { + return num_columns * inherited.matrix_stride; + } else { + // Row major case. + const auto column_type = words[2]; + const auto component_inst = vstate.FindDef(column_type); + const auto num_rows = component_inst->words()[3]; + const auto scalar_elem_type = component_inst->words()[2]; + const uint32_t scalar_elem_size = + getSize(scalar_elem_type, roundUp, inherited, constraints, vstate); + return (num_rows - 1) * inherited.matrix_stride + + num_columns * scalar_elem_size; + } + } + case SpvOpTypeStruct: { + const auto& members = getStructMembers(member_id, vstate); + if (members.empty()) return 0; + const auto lastIdx = uint32_t(members.size() - 1); + const auto& lastMember = members.back(); + uint32_t offset = 0xffffffff; + // Find the offset of the last element and add the size. + for (auto& decoration : vstate.id_decorations(member_id)) { + if (SpvDecorationOffset == decoration.dec_type() && + decoration.struct_member_index() == (int)lastIdx) { + offset = decoration.params()[0]; + } + } + // This check depends on the fact that all members have offsets. This + // has been checked earlier in the flow. + assert(offset != 0xffffffff); + const auto& constraint = constraints[std::make_pair(lastMember, lastIdx)]; + return offset + + getSize(lastMember, roundUp, constraint, constraints, vstate); + } + default: + assert(0); + return 0; + } +} + +// A member is defined to improperly straddle if either of the following are +// true: +// - It is a vector with total size less than or equal to 16 bytes, and has +// Offset decorations placing its first byte at F and its last byte at L, where +// floor(F / 16) != floor(L / 16). +// - It is a vector with total size greater than 16 bytes and has its Offset +// decorations placing its first byte at a non-integer multiple of 16. +bool hasImproperStraddle(uint32_t id, uint32_t offset, + const LayoutConstraints& inherited, + MemberConstraints& constraints, + ValidationState_t& vstate) { + const auto size = getSize(id, false, inherited, constraints, vstate); + const auto F = offset; + const auto L = offset + size - 1; + if (size <= 16) { + if ((F >> 4) != (L >> 4)) return true; + } else { + if (F % 16 != 0) return true; + } + return false; +} + +// Returns true if |offset| satsifies an alignment to |alignment|. In the case +// of |alignment| of zero, the |offset| must also be zero. +bool IsAlignedTo(uint32_t offset, uint32_t alignment) { + if (alignment == 0) return offset == 0; + return 0 == (offset % alignment); +} + +// Returns SPV_SUCCESS if the given struct satisfies standard layout rules for +// Block or BufferBlocks in Vulkan. Otherwise emits a diagnostic and returns +// something other than SPV_SUCCESS. Matrices inherit the specified column +// or row major-ness. +spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str, + const char* decoration_str, bool blockRules, + MemberConstraints& constraints, + ValidationState_t& vstate) { + if (vstate.options()->skip_block_layout) return SPV_SUCCESS; + + auto fail = [&vstate, struct_id, storage_class_str, decoration_str, + blockRules](uint32_t member_idx) -> DiagnosticStream { + DiagnosticStream ds = + std::move(vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id)) + << "Structure id " << struct_id << " decorated as " + << decoration_str << " for variable in " << storage_class_str + << " storage class must follow standard " + << (blockRules ? "uniform buffer" : "storage buffer") + << " layout rules: member " << member_idx << " "); + return ds; + }; + + const bool relaxed_block_layout = vstate.IsRelaxedBlockLayout(); + const auto& members = getStructMembers(struct_id, vstate); + + // To check for member overlaps, we want to traverse the members in + // offset order. + struct MemberOffsetPair { + uint32_t member; + uint32_t offset; + }; + std::vector member_offsets; + member_offsets.reserve(members.size()); + for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size()); + memberIdx < numMembers; memberIdx++) { + uint32_t offset = 0xffffffff; + for (auto& decoration : vstate.id_decorations(struct_id)) { + if (decoration.struct_member_index() == (int)memberIdx) { + switch (decoration.dec_type()) { + case SpvDecorationOffset: + offset = decoration.params()[0]; + break; + default: + break; + } + } + } + member_offsets.push_back(MemberOffsetPair{memberIdx, offset}); + } + std::stable_sort( + member_offsets.begin(), member_offsets.end(), + [](const MemberOffsetPair& lhs, const MemberOffsetPair& rhs) { + return lhs.offset < rhs.offset; + }); + + // Now scan from lowest offest to highest offset. + uint32_t nextValidOffset = 0; + for (size_t ordered_member_idx = 0; + ordered_member_idx < member_offsets.size(); ordered_member_idx++) { + const auto& member_offset = member_offsets[ordered_member_idx]; + const auto memberIdx = member_offset.member; + const auto offset = member_offset.offset; + auto id = members[member_offset.member]; + const LayoutConstraints& constraint = + constraints[std::make_pair(struct_id, uint32_t(memberIdx))]; + const auto alignment = + getBaseAlignment(id, blockRules, constraint, constraints, vstate); + const auto inst = vstate.FindDef(id); + const auto opcode = inst->opcode(); + const auto size = getSize(id, blockRules, constraint, constraints, vstate); + // Check offset. + if (offset == 0xffffffff) + return fail(memberIdx) << "is missing an Offset decoration"; + if (relaxed_block_layout && opcode == SpvOpTypeVector) { + // In relaxed block layout, the vector offset must be aligned to the + // vector's scalar element type. + const auto componentId = inst->words()[2]; + const auto scalar_alignment = getBaseAlignment( + componentId, blockRules, constraint, constraints, vstate); + if (!IsAlignedTo(offset, scalar_alignment)) { + return fail(memberIdx) + << "at offset " << offset + << " is not aligned to scalar element size " << scalar_alignment; + } + } else { + // Without relaxed block layout, the offset must be divisible by the + // base alignment. + if (!IsAlignedTo(offset, alignment)) { + return fail(memberIdx) + << "at offset " << offset << " is not aligned to " << alignment; + } + } + if (offset < nextValidOffset) + return fail(memberIdx) << "at offset " << offset + << " overlaps previous member ending at offset " + << nextValidOffset - 1; + if (relaxed_block_layout) { + // Check improper straddle of vectors. + if (SpvOpTypeVector == opcode && + hasImproperStraddle(id, offset, constraint, constraints, vstate)) + return fail(memberIdx) + << "is an improperly straddling vector at offset " << offset; + } + // Check struct members recursively. + spv_result_t recursive_status = SPV_SUCCESS; + if (SpvOpTypeStruct == opcode && + SPV_SUCCESS != (recursive_status = + checkLayout(id, storage_class_str, decoration_str, + blockRules, constraints, vstate))) + return recursive_status; + // Check matrix stride. + if (SpvOpTypeMatrix == opcode) { + for (auto& decoration : vstate.id_decorations(id)) { + if (SpvDecorationMatrixStride == decoration.dec_type() && + !IsAlignedTo(decoration.params()[0], alignment)) + return fail(memberIdx) + << "is a matrix with stride " << decoration.params()[0] + << " not satisfying alignment to " << alignment; + } + } + // Check arrays. + if (SpvOpTypeArray == opcode) { + const auto typeId = inst->word(2); + const auto arrayInst = vstate.FindDef(typeId); + if (SpvOpTypeStruct == arrayInst->opcode() && + SPV_SUCCESS != (recursive_status = checkLayout( + typeId, storage_class_str, decoration_str, + blockRules, constraints, vstate))) + return recursive_status; + // Check array stride. + for (auto& decoration : vstate.id_decorations(id)) { + if (SpvDecorationArrayStride == decoration.dec_type() && + !IsAlignedTo(decoration.params()[0], alignment)) + return fail(memberIdx) + << "is an array with stride " << decoration.params()[0] + << " not satisfying alignment to " << alignment; + } + } + nextValidOffset = offset + size; + if (blockRules && (SpvOpTypeArray == opcode || SpvOpTypeStruct == opcode)) { + // Uniform block rules don't permit anything in the padding of a struct + // or array. + nextValidOffset = align(nextValidOffset, alignment); + } + } + return SPV_SUCCESS; +} + +// Returns true if structure id has given decoration. Handles also nested +// structures. +bool hasDecoration(uint32_t struct_id, SpvDecoration decoration, + ValidationState_t& vstate) { + for (auto& dec : vstate.id_decorations(struct_id)) { + if (decoration == dec.dec_type()) return true; + } + for (auto id : getStructMembers(struct_id, SpvOpTypeStruct, vstate)) { + if (hasDecoration(id, decoration, vstate)) { + return true; + } + } + return false; +} + +// Returns true if all ids of given type have a specified decoration. +bool checkForRequiredDecoration(uint32_t struct_id, SpvDecoration decoration, + SpvOp type, ValidationState_t& vstate) { + const auto& members = getStructMembers(struct_id, vstate); + for (size_t memberIdx = 0; memberIdx < members.size(); memberIdx++) { + const auto id = members[memberIdx]; + if (type != vstate.FindDef(id)->opcode()) continue; + bool found = false; + for (auto& dec : vstate.id_decorations(id)) { + if (decoration == dec.dec_type()) found = true; + } + for (auto& dec : vstate.id_decorations(struct_id)) { + if (decoration == dec.dec_type() && + (int)memberIdx == dec.struct_member_index()) { + found = true; + } + } + if (!found) { + return false; + } + } + for (auto id : getStructMembers(struct_id, SpvOpTypeStruct, vstate)) { + if (!checkForRequiredDecoration(id, decoration, type, vstate)) { + return false; + } + } + return true; +} + +spv_result_t CheckLinkageAttrOfFunctions(ValidationState_t& vstate) { + for (const auto& function : vstate.functions()) { + if (function.block_count() == 0u) { + // A function declaration (an OpFunction with no basic blocks), must have + // a Linkage Attributes Decoration with the Import Linkage Type. + if (!hasImportLinkageAttribute(function.id(), vstate)) { + return vstate.diag(SPV_ERROR_INVALID_BINARY, + vstate.FindDef(function.id())) + << "Function declaration (id " << function.id() + << ") must have a LinkageAttributes decoration with the Import " + "Linkage type."; + } + } else { + if (hasImportLinkageAttribute(function.id(), vstate)) { + return vstate.diag(SPV_ERROR_INVALID_BINARY, + vstate.FindDef(function.id())) + << "Function definition (id " << function.id() + << ") may not be decorated with Import Linkage type."; + } + } + } + return SPV_SUCCESS; +} + +// Checks whether an imported variable is initialized by this module. +spv_result_t CheckImportedVariableInitialization(ValidationState_t& vstate) { + // According the SPIR-V Spec 2.16.1, it is illegal to initialize an imported + // variable. This means that a module-scope OpVariable with initialization + // value cannot be marked with the Import Linkage Type (import type id = 1). + for (auto global_var_id : vstate.global_vars()) { + // Initializer is an optional argument for OpVariable. If initializer + // is present, the instruction will have 5 words. + auto variable_instr = vstate.FindDef(global_var_id); + if (variable_instr->words().size() == 5u && + hasImportLinkageAttribute(global_var_id, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, variable_instr) + << "A module-scope OpVariable with initialization value " + "cannot be marked with the Import Linkage Type."; + } + } + return SPV_SUCCESS; +} + +// Checks whether a builtin variable is valid. +spv_result_t CheckBuiltInVariable(uint32_t var_id, ValidationState_t& vstate) { + const auto& decorations = vstate.id_decorations(var_id); + for (const auto& d : decorations) { + if (spvIsVulkanEnv(vstate.context()->target_env)) { + if (d.dec_type() == SpvDecorationLocation || + d.dec_type() == SpvDecorationComponent) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id)) + << "A BuiltIn variable (id " << var_id + << ") cannot have any Location or Component decorations"; + } + } + } + return SPV_SUCCESS; +} + +// Checks whether proper decorations have been appied to the entry points. +spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) { + for (uint32_t entry_point : vstate.entry_points()) { + const auto& descs = vstate.entry_point_descriptions(entry_point); + int num_builtin_inputs = 0; + int num_builtin_outputs = 0; + for (const auto& desc : descs) { + for (auto interface : desc.interfaces) { + Instruction* var_instr = vstate.FindDef(interface); + if (!var_instr || SpvOpVariable != var_instr->opcode()) { + return vstate.diag(SPV_ERROR_INVALID_ID, var_instr) + << "Interfaces passed to OpEntryPoint must be of type " + "OpTypeVariable. Found Op" + << spvOpcodeString(var_instr->opcode()) << "."; + } + const SpvStorageClass storage_class = + var_instr->GetOperandAs(2); + if (storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return vstate.diag(SPV_ERROR_INVALID_ID, var_instr) + << "OpEntryPoint interfaces must be OpVariables with " + "Storage Class of Input(1) or Output(3). Found Storage " + "Class " + << storage_class << " for Entry Point id " << entry_point + << "."; + } + + const uint32_t ptr_id = var_instr->word(1); + Instruction* ptr_instr = vstate.FindDef(ptr_id); + // It is guaranteed (by validator ID checks) that ptr_instr is + // OpTypePointer. Word 3 of this instruction is the type being pointed + // to. + const uint32_t type_id = ptr_instr->word(3); + Instruction* type_instr = vstate.FindDef(type_id); + if (type_instr && SpvOpTypeStruct == type_instr->opcode() && + isBuiltInStruct(type_id, vstate)) { + if (storage_class == SpvStorageClassInput) ++num_builtin_inputs; + if (storage_class == SpvStorageClassOutput) ++num_builtin_outputs; + if (num_builtin_inputs > 1 || num_builtin_outputs > 1) break; + if (auto error = CheckBuiltInVariable(interface, vstate)) + return error; + } else if (isBuiltInVar(interface, vstate)) { + if (auto error = CheckBuiltInVariable(interface, vstate)) + return error; + } + } + if (num_builtin_inputs > 1 || num_builtin_outputs > 1) { + return vstate.diag(SPV_ERROR_INVALID_BINARY, + vstate.FindDef(entry_point)) + << "There must be at most one object per Storage Class that can " + "contain a structure type containing members decorated with " + "BuiltIn, consumed per entry-point. Entry Point id " + << entry_point << " does not meet this requirement."; + } + // The LinkageAttributes Decoration cannot be applied to functions + // targeted by an OpEntryPoint instruction + for (auto& decoration : vstate.id_decorations(entry_point)) { + if (SpvDecorationLinkageAttributes == decoration.dec_type()) { + const char* linkage_name = + reinterpret_cast(&decoration.params()[0]); + return vstate.diag(SPV_ERROR_INVALID_BINARY, + vstate.FindDef(entry_point)) + << "The LinkageAttributes Decoration (Linkage name: " + << linkage_name << ") cannot be applied to function id " + << entry_point + << " because it is targeted by an OpEntryPoint instruction."; + } + } + } + } + return SPV_SUCCESS; +} + +spv_result_t CheckDescriptorSetArrayOfArrays(ValidationState_t& vstate) { + for (const auto& inst : vstate.ordered_instructions()) { + if (SpvOpVariable != inst.opcode()) continue; + + // Verify this variable is a DescriptorSet + bool has_descriptor_set = false; + for (const auto& decoration : vstate.id_decorations(inst.id())) { + if (SpvDecorationDescriptorSet == decoration.dec_type()) { + has_descriptor_set = true; + break; + } + } + if (!has_descriptor_set) continue; + + const auto* ptrInst = vstate.FindDef(inst.word(1)); + assert(SpvOpTypePointer == ptrInst->opcode()); + + // Check for a first level array + const auto typePtr = vstate.FindDef(ptrInst->word(3)); + if (SpvOpTypeRuntimeArray != typePtr->opcode() && + SpvOpTypeArray != typePtr->opcode()) { + continue; + } + + // Check for a second level array + const auto secondaryTypePtr = vstate.FindDef(typePtr->word(2)); + if (SpvOpTypeRuntimeArray == secondaryTypePtr->opcode() || + SpvOpTypeArray == secondaryTypePtr->opcode()) { + return vstate.diag(SPV_ERROR_INVALID_ID, &inst) + << "Only a single level of array is allowed for descriptor " + "set variables"; + } + } + return SPV_SUCCESS; +} + +// Load |constraints| with all the member constraints for structs contained +// within the given array type. +void ComputeMemberConstraintsForArray(MemberConstraints* constraints, + uint32_t array_id, + const LayoutConstraints& inherited, + ValidationState_t& vstate); + +// Load |constraints| with all the member constraints for the given struct, +// and all its contained structs. +void ComputeMemberConstraintsForStruct(MemberConstraints* constraints, + uint32_t struct_id, + const LayoutConstraints& inherited, + ValidationState_t& vstate) { + assert(constraints); + const auto& members = getStructMembers(struct_id, vstate); + for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size()); + memberIdx < numMembers; memberIdx++) { + LayoutConstraints& constraint = + (*constraints)[std::make_pair(struct_id, memberIdx)]; + constraint = inherited; + for (auto& decoration : vstate.id_decorations(struct_id)) { + if (decoration.struct_member_index() == (int)memberIdx) { + switch (decoration.dec_type()) { + case SpvDecorationRowMajor: + constraint.majorness = kRowMajor; + break; + case SpvDecorationColMajor: + constraint.majorness = kColumnMajor; + break; + case SpvDecorationMatrixStride: + constraint.matrix_stride = decoration.params()[0]; + break; + default: + break; + } + } + } + + // Now recurse + auto member_type_id = members[memberIdx]; + const auto member_type_inst = vstate.FindDef(member_type_id); + const auto opcode = member_type_inst->opcode(); + switch (opcode) { + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + ComputeMemberConstraintsForArray(constraints, member_type_id, inherited, + vstate); + break; + case SpvOpTypeStruct: + ComputeMemberConstraintsForStruct(constraints, member_type_id, + inherited, vstate); + break; + default: + break; + } + } +} + +void ComputeMemberConstraintsForArray(MemberConstraints* constraints, + uint32_t array_id, + const LayoutConstraints& inherited, + ValidationState_t& vstate) { + assert(constraints); + auto elem_type_id = vstate.FindDef(array_id)->words()[2]; + const auto elem_type_inst = vstate.FindDef(elem_type_id); + const auto opcode = elem_type_inst->opcode(); + switch (opcode) { + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + ComputeMemberConstraintsForArray(constraints, elem_type_id, inherited, + vstate); + break; + case SpvOpTypeStruct: + ComputeMemberConstraintsForStruct(constraints, elem_type_id, inherited, + vstate); + break; + default: + break; + } +} + +spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) { + for (const auto& inst : vstate.ordered_instructions()) { + const auto& words = inst.words(); + if (SpvOpVariable == inst.opcode()) { + // For storage class / decoration combinations, see Vulkan 14.5.4 "Offset + // and Stride Assignment". + const auto storageClass = words[3]; + const bool uniform = storageClass == SpvStorageClassUniform; + const bool push_constant = storageClass == SpvStorageClassPushConstant; + const bool storage_buffer = storageClass == SpvStorageClassStorageBuffer; + if (uniform || push_constant || storage_buffer) { + const auto ptrInst = vstate.FindDef(words[1]); + assert(SpvOpTypePointer == ptrInst->opcode()); + const auto id = ptrInst->words()[3]; + if (SpvOpTypeStruct != vstate.FindDef(id)->opcode()) continue; + MemberConstraints constraints; + ComputeMemberConstraintsForStruct(&constraints, id, LayoutConstraints(), + vstate); + // Prepare for messages + const char* sc_str = + uniform ? "Uniform" + : (push_constant ? "PushConstant" : "StorageBuffer"); + for (const auto& dec : vstate.id_decorations(id)) { + const bool blockDeco = SpvDecorationBlock == dec.dec_type(); + const bool bufferDeco = SpvDecorationBufferBlock == dec.dec_type(); + const bool blockRules = uniform && blockDeco; + const bool bufferRules = (uniform && bufferDeco) || + (push_constant && blockDeco) || + (storage_buffer && blockDeco); + if (blockRules || bufferRules) { + const char* deco_str = blockDeco ? "Block" : "BufferBlock"; + spv_result_t recursive_status = SPV_SUCCESS; + if (isMissingOffsetInStruct(id, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must be explicitly laid out with Offset decorations."; + } else if (hasDecoration(id, SpvDecorationGLSLShared, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must not use GLSLShared decoration."; + } else if (hasDecoration(id, SpvDecorationGLSLPacked, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must not use GLSLPacked decoration."; + } else if (!checkForRequiredDecoration(id, SpvDecorationArrayStride, + SpvOpTypeArray, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must be explicitly laid out with ArrayStride " + "decorations."; + } else if (!checkForRequiredDecoration(id, + SpvDecorationMatrixStride, + SpvOpTypeMatrix, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must be explicitly laid out with MatrixStride " + "decorations."; + } else if (blockRules && + (SPV_SUCCESS != (recursive_status = checkLayout( + id, sc_str, deco_str, true, + constraints, vstate)))) { + return recursive_status; + } else if (bufferRules && + (SPV_SUCCESS != (recursive_status = checkLayout( + id, sc_str, deco_str, false, + constraints, vstate)))) { + return recursive_status; + } + } + } + } + } + } + return SPV_SUCCESS; +} + +} // namespace + +// Validates that decorations have been applied properly. +spv_result_t ValidateDecorations(ValidationState_t& vstate) { + if (auto error = CheckImportedVariableInitialization(vstate)) return error; + if (auto error = CheckDecorationsOfEntryPoints(vstate)) return error; + if (auto error = CheckDecorationsOfBuffers(vstate)) return error; + if (auto error = CheckLinkageAttrOfFunctions(vstate)) return error; + if (auto error = CheckDescriptorSetArrayOfArrays(vstate)) return error; + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_derivatives.cpp b/3rdparty/spirv-tools/source/val/validate_derivatives.cpp similarity index 63% rename from 3rdparty/spirv-tools/source/validate_derivatives.cpp rename to 3rdparty/spirv-tools/source/val/validate_derivatives.cpp index 299f027f8..0e0dbbe3d 100644 --- a/3rdparty/spirv-tools/source/validate_derivatives.cpp +++ b/3rdparty/spirv-tools/source/val/validate_derivatives.cpp @@ -14,20 +14,22 @@ // Validates correctness of derivative SPIR-V instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include -namespace libspirv { +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { // Validates correctness of derivative instructions. -spv_result_t DerivativesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpDPdx: @@ -40,22 +42,24 @@ spv_result_t DerivativesPass(ValidationState_t& _, case SpvOpDPdyCoarse: case SpvOpFwidthCoarse: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be float scalar or vector type: " << spvOpcodeString(opcode); } const uint32_t p_type = _.GetOperandTypeId(inst, 2); if (p_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected P type and Result Type to be the same: " << spvOpcodeString(opcode); } - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelFragment, std::string( - "Derivative instructions require Fragment execution model: ") + - spvOpcodeString(opcode)); + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + std::string("Derivative instructions require Fragment execution " + "model: ") + + spvOpcodeString(opcode)); break; } @@ -66,4 +70,5 @@ spv_result_t DerivativesPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_execution_limitations.cpp b/3rdparty/spirv-tools/source/val/validate_execution_limitations.cpp new file mode 100644 index 000000000..d44930770 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_execution_limitations.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include "source/val/function.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateExecutionLimitations(ValidationState_t& _, + const Instruction* inst) { + if (inst->opcode() != SpvOpFunction) { + return SPV_SUCCESS; + } + + const auto func = _.function(inst->id()); + if (!func) { + return _.diag(SPV_ERROR_INTERNAL, inst) + << "Internal error: missing function id " << inst->id() << "."; + } + + for (uint32_t entry_id : _.FunctionEntryPoints(inst->id())) { + const auto* models = _.GetExecutionModels(entry_id); + if (models) { + if (models->empty()) { + return _.diag(SPV_ERROR_INTERNAL, inst) + << "Internal error: empty execution models for function id " + << entry_id << "."; + } + for (const auto model : *models) { + std::string reason; + if (!func->IsCompatibleWithExecutionModel(model, &reason)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_id) + << "'s callgraph contains function " + << _.getIdName(inst->id()) + << ", which cannot be used with the current execution " + "model:\n" + << reason; + } + } + } + } + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_ext_inst.cpp b/3rdparty/spirv-tools/source/val/validate_ext_inst.cpp similarity index 83% rename from 3rdparty/spirv-tools/source/validate_ext_inst.cpp rename to 3rdparty/spirv-tools/source/val/validate_ext_inst.cpp index e74dbdc2d..eb3427090 100644 --- a/3rdparty/spirv-tools/source/validate_ext_inst.cpp +++ b/3rdparty/spirv-tools/source/val/validate_ext_inst.cpp @@ -14,20 +14,21 @@ // Validates correctness of ExtInst SPIR-V instructions. -#include "validate.h" +#include "source/val/validate.h" #include +#include +#include -#include "latest_version_glsl_std_450_header.h" -#include "latest_version_opencl_std_header.h" - -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { +#include "source/diagnostic.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { uint32_t GetSizeTBitWidth(const ValidationState_t& _) { @@ -41,18 +42,17 @@ uint32_t GetSizeTBitWidth(const ValidationState_t& _) { } // anonymous namespace // Validates correctness of ExtInst instructions. -spv_result_t ExtInstPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; - const uint32_t num_operands = inst->num_operands; +spv_result_t ExtInstPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); + const uint32_t num_operands = static_cast(inst->operands().size()); if (opcode != SpvOpExtInst) return SPV_SUCCESS; - const uint32_t ext_inst_set = inst->words[3]; - const uint32_t ext_inst_index = inst->words[4]; + const uint32_t ext_inst_set = inst->word(3); + const uint32_t ext_inst_index = inst->word(4); const spv_ext_inst_type_t ext_inst_type = - spv_ext_inst_type_t(inst->ext_inst_type); + spv_ext_inst_type_t(inst->ext_inst_type()); auto ext_inst_name = [&_, ext_inst_set, ext_inst_type, ext_inst_index]() { spv_ext_inst_desc desc = nullptr; @@ -100,7 +100,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450NMax: case GLSLstd450NClamp: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } @@ -109,7 +109,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (result_type != operand_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected types of all operands to be equal to Result " "Type"; @@ -130,7 +130,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450FindUMsb: case GLSLstd450FindSMsb: { if (!_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int scalar or vector type"; } @@ -142,20 +142,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (!_.IsIntScalarOrVectorType(operand_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected all operands to be int scalars or vectors"; } if (result_type_dimension != _.GetDimension(operand_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected all operands to have the same dimension as " << "Result Type"; } if (result_type_bit_width != _.GetBitWidth(operand_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected all operands to have the same bit width as " << "Result Type"; @@ -164,7 +164,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (ext_inst_key == GLSLstd450FindUMsb || ext_inst_key == GLSLstd450FindSMsb) { if (result_type_bit_width != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "this instruction is currently limited to 32-bit width " << "components"; @@ -195,7 +195,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Atan2: case GLSLstd450Pow: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 16 or 32-bit scalar or " "vector float type"; @@ -203,7 +203,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t result_type_bit_width = _.GetBitWidth(result_type); if (result_type_bit_width != 16 && result_type_bit_width != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 16 or 32-bit scalar or " "vector float type"; @@ -213,7 +213,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (result_type != operand_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected types of all operands to be equal to Result " "Type"; @@ -231,13 +231,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.GetMatrixTypeInfo(x_type, &num_rows, &num_cols, &col_type, &component_type) || num_rows != num_cols) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X to be a square matrix"; } if (result_type != component_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X component type to be equal to " << "Result Type"; @@ -253,14 +253,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type, &component_type) || num_rows != num_cols) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a square matrix"; } const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (result_type != x_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to Result Type"; } @@ -269,7 +269,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Modf: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or vector float type"; } @@ -278,7 +278,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t i_type = _.GetOperandTypeId(inst, 5); if (x_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to Result Type"; } @@ -286,13 +286,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t i_storage_class = 0; uint32_t i_data_type = 0; if (!_.GetPointerTypeInfo(i_type, &i_data_type, &i_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand I to be a pointer"; } if (i_data_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand I data type to be equal to Result Type"; } @@ -306,7 +306,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, result_types.size() != 2 || !_.IsFloatScalarOrVectorType(result_types[0]) || result_types[1] != result_types[0]) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a struct with two identical " << "scalar or vector float type members"; @@ -314,7 +314,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (x_type != result_types[0]) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to members of " << "Result Type struct"; @@ -324,7 +324,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Frexp: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or vector float type"; } @@ -333,7 +333,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t exp_type = _.GetOperandTypeId(inst, 5); if (x_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to Result Type"; } @@ -342,21 +342,28 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t exp_data_type = 0; if (!_.GetPointerTypeInfo(exp_type, &exp_data_type, &exp_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Exp to be a pointer"; } if (!_.IsIntScalarOrVectorType(exp_data_type) || - _.GetBitWidth(exp_data_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + (!_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(exp_data_type) != 32) || + (_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(exp_data_type) != 16 && + _.GetBitWidth(exp_data_type) != 32)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " - << "expected operand Exp data type to be a 32-bit int scalar " - << "or vector type"; + << "expected operand Exp data type to be a " + << (_.HasExtension(kSPV_AMD_gpu_shader_int16) + ? "16-bit or 32-bit " + : "32-bit ") + << "int scalar or vector type"; } if (_.GetDimension(result_type) != _.GetDimension(exp_data_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Exp data type to have the same component " << "number as Result Type"; @@ -367,7 +374,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Ldexp: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or vector float type"; } @@ -376,20 +383,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t exp_type = _.GetOperandTypeId(inst, 5); if (x_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to Result Type"; } if (!_.IsIntScalarOrVectorType(exp_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Exp to be a 32-bit int scalar " << "or vector type"; } if (_.GetDimension(result_type) != _.GetDimension(exp_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Exp to have the same component " << "number as Result Type"; @@ -404,20 +411,27 @@ spv_result_t ExtInstPass(ValidationState_t& _, result_types.size() != 2 || !_.IsFloatScalarOrVectorType(result_types[0]) || !_.IsIntScalarOrVectorType(result_types[1]) || - _.GetBitWidth(result_types[1]) != 32 || + (!_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(result_types[1]) != 32) || + (_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(result_types[1]) != 16 && + _.GetBitWidth(result_types[1]) != 32) || _.GetDimension(result_types[0]) != _.GetDimension(result_types[1])) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a struct with two members, " - << "first member a float scalar or vector, second member " - << "a 32-bit int scalar or vector with the same number of " + << "first member a float scalar or vector, second member a " + << (_.HasExtension(kSPV_AMD_gpu_shader_int16) + ? "16-bit or 32-bit " + : "32-bit ") + << "int scalar or vector with the same number of " << "components as the first member"; } const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (x_type != result_types[0]) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to the first member " << "of Result Type struct"; @@ -429,7 +443,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450PackUnorm4x8: { if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be 32-bit int scalar type"; } @@ -437,7 +451,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t v_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatVectorType(v_type) || _.GetDimension(v_type) != 4 || _.GetBitWidth(v_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand V to be a 32-bit float vector of size 4"; } @@ -449,7 +463,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450PackHalf2x16: { if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be 32-bit int scalar type"; } @@ -457,7 +471,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t v_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatVectorType(v_type) || _.GetDimension(v_type) != 2 || _.GetBitWidth(v_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand V to be a 32-bit float vector of size 2"; } @@ -467,7 +481,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450PackDouble2x32: { if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 64) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be 64-bit float scalar type"; } @@ -475,7 +489,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t v_type = _.GetOperandTypeId(inst, 4); if (!_.IsIntVectorType(v_type) || _.GetDimension(v_type) != 2 || _.GetBitWidth(v_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand V to be a 32-bit int vector of size 2"; } @@ -487,7 +501,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsFloatVectorType(result_type) || _.GetDimension(result_type) != 4 || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit float vector of size " "4"; @@ -495,7 +509,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t v_type = _.GetOperandTypeId(inst, 4); if (!_.IsIntScalarType(v_type) || _.GetBitWidth(v_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a 32-bit int scalar"; } @@ -508,7 +522,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsFloatVectorType(result_type) || _.GetDimension(result_type) != 2 || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit float vector of size " "2"; @@ -516,7 +530,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t v_type = _.GetOperandTypeId(inst, 4); if (!_.IsIntScalarType(v_type) || _.GetBitWidth(v_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a 32-bit int scalar"; } @@ -527,7 +541,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsIntVectorType(result_type) || _.GetDimension(result_type) != 2 || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit int vector of size " "2"; @@ -535,7 +549,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t v_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatScalarType(v_type) || _.GetBitWidth(v_type) != 64) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand V to be a 64-bit float scalar"; } @@ -544,20 +558,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Length: { if (!_.IsFloatScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar type"; } const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatScalarOrVectorType(x_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X to be of float scalar or vector type"; } if (result_type != _.GetComponentType(x_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X component type to be equal to Result " "Type"; @@ -567,20 +581,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Distance: { if (!_.IsFloatScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar type"; } const uint32_t p0_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatScalarOrVectorType(p0_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P0 to be of float scalar or vector type"; } if (result_type != _.GetComponentType(p0_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P0 component type to be equal to " << "Result Type"; @@ -588,20 +602,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t p1_type = _.GetOperandTypeId(inst, 5); if (!_.IsFloatScalarOrVectorType(p1_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P1 to be of float scalar or vector type"; } if (result_type != _.GetComponentType(p1_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P1 component type to be equal to " << "Result Type"; } if (_.GetDimension(p0_type) != _.GetDimension(p1_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operands P0 and P1 to have the same number of " << "components"; @@ -611,13 +625,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Cross: { if (!_.IsFloatVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float vector type"; } if (_.GetDimension(result_type) != 3) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to have 3 components"; } @@ -626,13 +640,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t y_type = _.GetOperandTypeId(inst, 5); if (x_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to Result Type"; } if (y_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Y type to be equal to Result Type"; } @@ -641,7 +655,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450Refract: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } @@ -651,23 +665,21 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t eta_type = _.GetOperandTypeId(inst, 6); if (result_type != i_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand I to be of type equal to Result Type"; } if (result_type != n_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand N to be of type equal to Result Type"; } - const uint32_t eta_type_bit_width = _.GetBitWidth(eta_type); - if (!_.IsFloatScalarType(eta_type) || - (eta_type_bit_width != 16 && eta_type_bit_width != 32)) { - return _.diag(SPV_ERROR_INVALID_DATA) + if (!_.IsFloatScalarType(eta_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " - << "expected operand Eta to be a 16 or 32-bit float scalar"; + << "expected operand Eta to be a float scalar"; } break; } @@ -676,14 +688,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case GLSLstd450InterpolateAtSample: case GLSLstd450InterpolateAtOffset: { if (!_.HasCapability(SpvCapabilityInterpolationFunction)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) << ext_inst_name() << " requires capability InterpolationFunction"; } if (!_.IsFloatScalarOrVectorType(result_type) || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit float scalar " << "or vector type"; @@ -694,19 +706,19 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t interpolant_data_type = 0; if (!_.GetPointerTypeInfo(interpolant_type, &interpolant_data_type, &interpolant_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Interpolant to be a pointer"; } if (result_type != interpolant_data_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Interpolant data type to be equal to Result Type"; } if (interpolant_storage_class != SpvStorageClassInput) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Interpolant storage class to be Input"; } @@ -715,7 +727,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t sample_type = _.GetOperandTypeId(inst, 5); if (!_.IsIntScalarType(sample_type) || _.GetBitWidth(sample_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Sample to be 32-bit integer"; } @@ -726,26 +738,27 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsFloatVectorType(offset_type) || _.GetDimension(offset_type) != 2 || _.GetBitWidth(offset_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Offset to be a vector of 2 32-bit floats"; } } - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelFragment, - ext_inst_name() + - std::string(" requires Fragment execution model")); + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + ext_inst_name() + + std::string(" requires Fragment execution model")); break; } case GLSLstd450IMix: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Extended instruction GLSLstd450IMix is not supported"; } case GLSLstd450Bad: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Encountered extended instruction GLSLstd450Bad"; } @@ -852,14 +865,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Smoothstep: case OpenCLLIB::Sign: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -869,7 +882,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (result_type != operand_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected types of all operands to be equal to Result " "Type"; @@ -883,14 +896,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Sincos: case OpenCLLIB::Remquo: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -899,7 +912,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t operand_index = 4; const uint32_t x_type = _.GetOperandTypeId(inst, operand_index++); if (result_type != x_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected type of operand X to be equal to Result Type"; } @@ -907,7 +920,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (ext_inst_key == OpenCLLIB::Remquo) { const uint32_t y_type = _.GetOperandTypeId(inst, operand_index++); if (result_type != y_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected type of operand Y to be equal to Result Type"; } @@ -917,7 +930,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected the last operand to be a pointer"; } @@ -926,14 +939,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, p_storage_class != SpvStorageClassCrossWorkgroup && p_storage_class != SpvStorageClassWorkgroup && p_storage_class != SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected storage class of the pointer to be Generic, " "CrossWorkgroup, Workgroup or Function"; } if (result_type != p_data_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected data type of the pointer to be equal to Result " "Type"; @@ -944,14 +957,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Frexp: case OpenCLLIB::Lgamma_r: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -959,7 +972,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (result_type != x_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected type of operand X to be equal to Result Type"; } @@ -968,7 +981,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected the last operand to be a pointer"; } @@ -977,7 +990,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, p_storage_class != SpvStorageClassCrossWorkgroup && p_storage_class != SpvStorageClassWorkgroup && p_storage_class != SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected storage class of the pointer to be Generic, " "CrossWorkgroup, Workgroup or Function"; @@ -985,14 +998,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsIntScalarOrVectorType(p_data_type) || _.GetBitWidth(p_data_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected data type of the pointer to be a 32-bit int " "scalar or vector type"; } if (_.GetDimension(p_data_type) != num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected data type of the pointer to have the same number " "of components as Result Type"; @@ -1003,7 +1016,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Ilogb: { if (!_.IsIntScalarOrVectorType(result_type) || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit int scalar or vector " "type"; @@ -1011,7 +1024,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1019,13 +1032,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatScalarOrVectorType(x_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X to be a float scalar or vector"; } if (_.GetDimension(x_type) != num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X to have the same number of components " "as Result Type"; @@ -1037,14 +1050,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Pown: case OpenCLLIB::Rootn: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1052,7 +1065,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t x_type = _.GetOperandTypeId(inst, 4); if (result_type != x_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected type of operand X to be equal to Result Type"; } @@ -1060,13 +1073,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t exp_type = _.GetOperandTypeId(inst, 5); if (!_.IsIntScalarOrVectorType(exp_type) || _.GetBitWidth(exp_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected the exponent to be a 32-bit int scalar or vector"; } if (_.GetDimension(exp_type) != num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected the exponent to have the same number of " "components as Result Type"; @@ -1076,14 +1089,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Nan: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1091,20 +1104,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t nancode_type = _.GetOperandTypeId(inst, 4); if (!_.IsIntScalarOrVectorType(nancode_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Nancode to be an int scalar or vector type"; } if (_.GetDimension(nancode_type) != num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Nancode to have the same number of components as " "Result Type"; } if (_.GetBitWidth(result_type) != _.GetBitWidth(nancode_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Nancode to have the same bit width as Result " "Type"; @@ -1141,14 +1154,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::UMul_hi: case OpenCLLIB::UMad_hi: { if (!_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1158,7 +1171,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (result_type != operand_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected types of all operands to be equal to Result " "Type"; @@ -1170,7 +1183,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::U_Upsample: case OpenCLLIB::S_Upsample: { if (!_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int scalar or vector " "type"; @@ -1179,7 +1192,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t result_num_components = _.GetDimension(result_type); if (result_num_components > 4 && result_num_components != 8 && result_num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1188,7 +1201,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t result_bit_width = _.GetBitWidth(result_type); if (result_bit_width != 16 && result_bit_width != 32 && result_bit_width != 64) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected bit width of Result Type components to be 16, 32 " "or 64"; @@ -1198,20 +1211,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t lo_type = _.GetOperandTypeId(inst, 5); if (hi_type != lo_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Hi and Lo operands to have the same type"; } if (result_num_components != _.GetDimension(hi_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Hi and Lo operands to have the same number of " "components as Result Type"; } if (result_bit_width != 2 * _.GetBitWidth(hi_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected bit width of components of Hi and Lo operands to " "be half of the bit width of components of Result Type"; @@ -1225,7 +1238,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::UMul24: { if (!_.IsIntScalarOrVectorType(result_type) || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit int scalar or vector " "type"; @@ -1233,7 +1246,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1243,7 +1256,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (result_type != operand_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected types of all operands to be equal to Result " "Type"; @@ -1254,14 +1267,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Cross: { if (!_.IsFloatVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components != 3 && num_components != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to have 3 or 4 components"; } @@ -1270,13 +1283,13 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t y_type = _.GetOperandTypeId(inst, 5); if (x_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X type to be equal to Result Type"; } if (y_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Y type to be equal to Result Type"; } @@ -1286,27 +1299,27 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Distance: case OpenCLLIB::Fast_distance: { if (!_.IsFloatScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar type"; } const uint32_t p0_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatScalarOrVectorType(p0_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P0 to be of float scalar or vector type"; } const uint32_t num_components = _.GetDimension(p0_type); if (num_components > 4) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P0 to have no more than 4 components"; } if (result_type != _.GetComponentType(p0_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P0 component type to be equal to " << "Result Type"; @@ -1314,7 +1327,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t p1_type = _.GetOperandTypeId(inst, 5); if (p0_type != p1_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operands P0 and P1 to be of the same type"; } @@ -1324,27 +1337,27 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Length: case OpenCLLIB::Fast_length: { if (!_.IsFloatScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar type"; } const uint32_t p_type = _.GetOperandTypeId(inst, 4); if (!_.IsFloatScalarOrVectorType(p_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a float scalar or vector"; } const uint32_t num_components = _.GetDimension(p_type); if (num_components > 4) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to have no more than 4 components"; } if (result_type != _.GetComponentType(p_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P component type to be equal to Result " "Type"; @@ -1355,21 +1368,21 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Normalize: case OpenCLLIB::Fast_normalize: { if (!_.IsFloatScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to have no more than 4 components"; } const uint32_t p_type = _.GetOperandTypeId(inst, 4); if (p_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P type to be equal to Result Type"; } @@ -1379,7 +1392,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Bitselect: { if (!_.IsFloatScalarOrVectorType(result_type) && !_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int or float scalar or " "vector type"; @@ -1387,7 +1400,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1397,7 +1410,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); if (result_type != operand_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected types of all operands to be equal to Result " "Type"; @@ -1409,7 +1422,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Select: { if (!_.IsFloatScalarOrVectorType(result_type) && !_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int or float scalar or " "vector type"; @@ -1417,7 +1430,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1428,32 +1441,32 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t c_type = _.GetOperandTypeId(inst, 6); if (result_type != a_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand A type to be equal to Result Type"; } if (result_type != b_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand B type to be equal to Result Type"; } if (!_.IsIntScalarOrVectorType(c_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand C to be an int scalar or vector"; } if (num_components != _.GetDimension(c_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand C to have the same number of components " "as Result Type"; } if (_.GetBitWidth(result_type) != _.GetBitWidth(c_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand C to have the same bit width as Result " "Type"; @@ -1464,14 +1477,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Vloadn: { if (!_.IsFloatVectorType(result_type) && !_.IsIntVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int or float vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to have 2, 3, 4, 8 or 16 components"; } @@ -1481,14 +1494,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t size_t_bit_width = GetSizeTBitWidth(_); if (!size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << " can only be used with physical addressing models"; } if (!_.IsIntScalarType(offset_type) || _.GetBitWidth(offset_type) != size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Offset to be of type size_t (" << size_t_bit_width @@ -1498,29 +1511,29 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a pointer"; } if (p_storage_class != SpvStorageClassUniformConstant && p_storage_class != SpvStorageClassGeneric) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P storage class to be UniformConstant or " "Generic"; } if (_.GetComponentType(result_type) != p_data_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be equal to component " "type of Result Type"; } - const uint32_t n_value = inst->words[7]; + const uint32_t n_value = inst->word(7); if (num_components != n_value) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected literal N to be equal to the number of " "components of Result Type"; @@ -1530,7 +1543,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Vstoren: { if (_.GetIdOpcode(result_type) != SpvOpTypeVoid) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": expected Result Type to be void"; } @@ -1539,28 +1552,28 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t p_type = _.GetOperandTypeId(inst, 6); if (!_.IsFloatVectorType(data_type) && !_.IsIntVectorType(data_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Data to be an int or float vector"; } const uint32_t num_components = _.GetDimension(data_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Data to have 2, 3, 4, 8 or 16 components"; } const uint32_t size_t_bit_width = GetSizeTBitWidth(_); if (!size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << " can only be used with physical addressing models"; } if (!_.IsIntScalarType(offset_type) || _.GetBitWidth(offset_type) != size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Offset to be of type size_t (" << size_t_bit_width @@ -1570,19 +1583,19 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a pointer"; } if (p_storage_class != SpvStorageClassGeneric) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P storage class to be Generic"; } if (_.GetComponentType(data_type) != p_data_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be equal to the type of " "operand Data components"; @@ -1592,7 +1605,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Vload_half: { if (!_.IsFloatScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar type"; } @@ -1602,14 +1615,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t size_t_bit_width = GetSizeTBitWidth(_); if (!size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << " can only be used with physical addressing models"; } if (!_.IsIntScalarType(offset_type) || _.GetBitWidth(offset_type) != size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Offset to be of type size_t (" << size_t_bit_width @@ -1619,7 +1632,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a pointer"; } @@ -1629,7 +1642,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, p_storage_class != SpvStorageClassCrossWorkgroup && p_storage_class != SpvStorageClassWorkgroup && p_storage_class != SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P storage class to be UniformConstant, " "Generic, CrossWorkgroup, Workgroup or Function"; @@ -1637,7 +1650,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsFloatScalarType(p_data_type) || _.GetBitWidth(p_data_type) != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be 16-bit float scalar"; } @@ -1647,14 +1660,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Vload_halfn: case OpenCLLIB::Vloada_halfn: { if (!_.IsFloatVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float vector type"; } const uint32_t num_components = _.GetDimension(result_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to have 2, 3, 4, 8 or 16 components"; } @@ -1664,14 +1677,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t size_t_bit_width = GetSizeTBitWidth(_); if (!size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << " can only be used with physical addressing models"; } if (!_.IsIntScalarType(offset_type) || _.GetBitWidth(offset_type) != size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Offset to be of type size_t (" << size_t_bit_width @@ -1681,7 +1694,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a pointer"; } @@ -1691,7 +1704,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, p_storage_class != SpvStorageClassCrossWorkgroup && p_storage_class != SpvStorageClassWorkgroup && p_storage_class != SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P storage class to be UniformConstant, " "Generic, CrossWorkgroup, Workgroup or Function"; @@ -1699,14 +1712,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsFloatScalarType(p_data_type) || _.GetBitWidth(p_data_type) != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be 16-bit float scalar"; } - const uint32_t n_value = inst->words[7]; + const uint32_t n_value = inst->word(7); if (num_components != n_value) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected literal N to be equal to the number of " "components of Result Type"; @@ -1721,7 +1734,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Vstorea_halfn: case OpenCLLIB::Vstorea_halfn_r: { if (_.GetIdOpcode(result_type) != SpvOpTypeVoid) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": expected Result Type to be void"; } @@ -1734,14 +1747,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, ext_inst_key == OpenCLLIB::Vstore_half_r) { if (!_.IsFloatScalarType(data_type) || (data_type_bit_width != 32 && data_type_bit_width != 64)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Data to be a 32 or 64-bit float scalar"; } } else { if (!_.IsFloatVectorType(data_type) || (data_type_bit_width != 32 && data_type_bit_width != 64)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Data to be a 32 or 64-bit float vector"; } @@ -1749,7 +1762,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t num_components = _.GetDimension(data_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Data to have 2, 3, 4, 8 or 16 components"; } @@ -1757,14 +1770,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t size_t_bit_width = GetSizeTBitWidth(_); if (!size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << " can only be used with physical addressing models"; } if (!_.IsIntScalarType(offset_type) || _.GetBitWidth(offset_type) != size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Offset to be of type size_t (" << size_t_bit_width @@ -1774,7 +1787,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P to be a pointer"; } @@ -1783,7 +1796,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, p_storage_class != SpvStorageClassCrossWorkgroup && p_storage_class != SpvStorageClassWorkgroup && p_storage_class != SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P storage class to be Generic, " "CrossWorkgroup, Workgroup or Function"; @@ -1791,7 +1804,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (!_.IsFloatScalarType(p_data_type) || _.GetBitWidth(p_data_type) != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be 16-bit float scalar"; } @@ -1804,7 +1817,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Shuffle2: { if (!_.IsFloatVectorType(result_type) && !_.IsIntVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int or float vector type"; } @@ -1812,7 +1825,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t result_num_components = _.GetDimension(result_type); if (result_num_components != 2 && result_num_components != 4 && result_num_components != 8 && result_num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to have 2, 4, 8 or 16 components"; } @@ -1823,7 +1836,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (ext_inst_key == OpenCLLIB::Shuffle2) { const uint32_t y_type = _.GetOperandTypeId(inst, operand_index++); if (x_type != y_type) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operands X and Y to be of the same type"; } @@ -1833,7 +1846,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, _.GetOperandTypeId(inst, operand_index++); if (!_.IsFloatVectorType(x_type) && !_.IsIntVectorType(x_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X to be an int or float vector"; } @@ -1841,7 +1854,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t x_num_components = _.GetDimension(x_type); if (x_num_components != 2 && x_num_components != 4 && x_num_components != 8 && x_num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X to have 2, 4, 8 or 16 components"; } @@ -1849,20 +1862,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t result_component_type = _.GetComponentType(result_type); if (result_component_type != _.GetComponentType(x_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand X and Result Type to have equal " "component types"; } if (!_.IsIntVectorType(shuffle_mask_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Shuffle Mask to be an int vector"; } if (result_num_components != _.GetDimension(shuffle_mask_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Shuffle Mask to have the same number of " "components as Result Type"; @@ -1870,7 +1883,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, if (_.GetBitWidth(result_component_type) != _.GetBitWidth(shuffle_mask_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Shuffle Mask components to have the same " "bit width as Result Type components"; @@ -1881,7 +1894,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Printf: { if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 32-bit int type"; } @@ -1891,20 +1904,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t format_data_type = 0; if (!_.GetPointerTypeInfo(format_type, &format_data_type, &format_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Format to be a pointer"; } if (format_storage_class != SpvStorageClassUniformConstant) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Format storage class to be UniformConstant"; } if (!_.IsIntScalarType(format_data_type) || _.GetBitWidth(format_data_type) != 8) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Format data type to be 8-bit int"; } @@ -1913,7 +1926,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, case OpenCLLIB::Prefetch: { if (_.GetIdOpcode(result_type) != SpvOpTypeVoid) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": expected Result Type to be void"; } @@ -1923,20 +1936,20 @@ spv_result_t ExtInstPass(ValidationState_t& _, uint32_t p_storage_class = 0; uint32_t p_data_type = 0; if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Ptr to be a pointer"; } if (p_storage_class != SpvStorageClassCrossWorkgroup) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Ptr storage class to be CrossWorkgroup"; } if (!_.IsFloatScalarOrVectorType(p_data_type) && !_.IsIntScalarOrVectorType(p_data_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Ptr data type to be int or float scalar or " "vector"; @@ -1944,7 +1957,7 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t num_components = _.GetDimension(p_data_type); if (num_components > 4 && num_components != 8 && num_components != 16) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a scalar or a vector with 2, " "3, 4, 8 or 16 components"; @@ -1952,14 +1965,14 @@ spv_result_t ExtInstPass(ValidationState_t& _, const uint32_t size_t_bit_width = GetSizeTBitWidth(_); if (!size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << " can only be used with physical addressing models"; } if (!_.IsIntScalarType(num_elements_type) || _.GetBitWidth(num_elements_type) != size_t_bit_width) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand Num Elements to be of type size_t (" << size_t_bit_width @@ -1973,4 +1986,5 @@ spv_result_t ExtInstPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_function.cpp b/3rdparty/spirv-tools/source/val/validate_function.cpp new file mode 100644 index 000000000..39f00fedc --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_function.cpp @@ -0,0 +1,202 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { + const auto function_type_id = inst->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (!function_type || SpvOpTypeFunction != function_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunction Function Type '" << _.getIdName(function_type_id) + << "' is not a function type."; + } + + const auto return_id = function_type->GetOperandAs(1); + if (return_id != inst->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunction Result Type '" << _.getIdName(inst->type_id()) + << "' does not match the Function Type's return type '" + << _.getIdName(return_id) << "'."; + } + + for (auto& pair : inst->uses()) { + const auto* use = pair.first; + const std::vector acceptable = { + SpvOpFunctionCall, + SpvOpEntryPoint, + SpvOpEnqueueKernel, + SpvOpGetKernelNDrangeSubGroupCount, + SpvOpGetKernelNDrangeMaxSubGroupSize, + SpvOpGetKernelWorkGroupSize, + SpvOpGetKernelPreferredWorkGroupSizeMultiple, + SpvOpGetKernelLocalSizeForSubgroupCount, + SpvOpGetKernelMaxNumSubgroups}; + if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == + acceptable.end()) { + return _.diag(SPV_ERROR_INVALID_ID, use) + << "Invalid use of function result id " << _.getIdName(inst->id()) + << "."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateFunctionParameter(ValidationState_t& _, + const Instruction* inst) { + // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. + size_t param_index = 0; + size_t inst_num = inst->LineNum() - 1; + if (inst_num == 0) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameter cannot be the first instruction."; + } + + auto func_inst = &_.ordered_instructions()[inst_num]; + while (--inst_num) { + func_inst = &_.ordered_instructions()[inst_num]; + if (func_inst->opcode() == SpvOpFunction) { + break; + } else if (func_inst->opcode() == SpvOpFunctionParameter) { + ++param_index; + } + } + + if (func_inst->opcode() != SpvOpFunction) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameter must be preceded by a function."; + } + + const auto function_type_id = func_inst->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (!function_type) { + return _.diag(SPV_ERROR_INVALID_ID, func_inst) + << "Missing function type definition."; + } + if (param_index >= function_type->words().size() - 3) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Too many OpFunctionParameters for " << func_inst->id() + << ": expected " << function_type->words().size() - 3 + << " based on the function's type"; + } + + const auto param_type = + _.FindDef(function_type->GetOperandAs(param_index + 2)); + if (!param_type || inst->type_id() != param_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionParameter Result Type '" + << _.getIdName(inst->type_id()) + << "' does not match the OpTypeFunction parameter " + "type of the same index."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateFunctionCall(ValidationState_t& _, + const Instruction* inst) { + const auto function_id = inst->GetOperandAs(2); + const auto function = _.FindDef(function_id); + if (!function || SpvOpFunction != function->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Function '" << _.getIdName(function_id) + << "' is not a function."; + } + + auto return_type = _.FindDef(function->type_id()); + if (!return_type || return_type->id() != inst->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Result Type '" + << _.getIdName(inst->type_id()) + << "'s type does not match Function '" + << _.getIdName(return_type->id()) << "'s return type."; + } + + const auto function_type_id = function->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (!function_type || function_type->opcode() != SpvOpTypeFunction) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Missing function type definition."; + } + + const auto function_call_arg_count = inst->words().size() - 4; + const auto function_param_count = function_type->words().size() - 3; + if (function_param_count != function_call_arg_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Function 's parameter count does not match " + "the argument count."; + } + + for (size_t argument_index = 3, param_index = 2; + argument_index < inst->operands().size(); + argument_index++, param_index++) { + const auto argument_id = inst->GetOperandAs(argument_index); + const auto argument = _.FindDef(argument_id); + if (!argument) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Missing argument " << argument_index - 3 << " definition."; + } + + const auto argument_type = _.FindDef(argument->type_id()); + if (!argument_type) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Missing argument " << argument_index - 3 + << " type definition."; + } + + const auto parameter_type_id = + function_type->GetOperandAs(param_index); + const auto parameter_type = _.FindDef(parameter_type_id); + if (!parameter_type || argument_type->id() != parameter_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Argument '" << _.getIdName(argument_id) + << "'s type does not match Function '" + << _.getIdName(parameter_type_id) << "'s parameter type."; + } + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpFunction: + if (auto error = ValidateFunction(_, inst)) return error; + break; + case SpvOpFunctionParameter: + if (auto error = ValidateFunctionParameter(_, inst)) return error; + break; + case SpvOpFunctionCall: + if (auto error = ValidateFunctionCall(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_id.cpp b/3rdparty/spirv-tools/source/val/validate_id.cpp new file mode 100644 index 000000000..6359ab600 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_id.cpp @@ -0,0 +1,204 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_validator_options.h" +#include "source/val/function.h" +#include "source/val/validation_state.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { + +spv_result_t UpdateIdUse(ValidationState_t& _, const Instruction* inst) { + for (auto& operand : inst->operands()) { + const spv_operand_type_t& type = operand.type; + const uint32_t operand_id = inst->word(operand.offset); + if (spvIsIdType(type) && type != SPV_OPERAND_TYPE_RESULT_ID) { + if (auto def = _.FindDef(operand_id)) + def->RegisterUse(inst, operand.offset); + } + } + + return SPV_SUCCESS; +} + +/// This function checks all ID definitions dominate their use in the CFG. +/// +/// This function will iterate over all ID definitions that are defined in the +/// functions of a module and make sure that the definitions appear in a +/// block that dominates their use. +/// +/// NOTE: This function does NOT check module scoped functions which are +/// checked during the initial binary parse in the IdPass below +spv_result_t CheckIdDefinitionDominateUse(const ValidationState_t& _) { + std::vector phi_instructions; + std::unordered_set phi_ids; + for (const auto& inst : _.ordered_instructions()) { + if (inst.id() == 0) continue; + if (const Function* func = inst.function()) { + if (const BasicBlock* block = inst.block()) { + if (!block->reachable()) continue; + // If the Id is defined within a block then make sure all references to + // that Id appear in a blocks that are dominated by the defining block + for (auto& use_index_pair : inst.uses()) { + const Instruction* use = use_index_pair.first; + if (const BasicBlock* use_block = use->block()) { + if (use_block->reachable() == false) continue; + if (use->opcode() == SpvOpPhi) { + if (phi_ids.insert(use->id()).second) { + phi_instructions.push_back(use); + } + } else if (!block->dominates(*use->block())) { + return _.diag(SPV_ERROR_INVALID_ID, use_block->label()) + << "ID " << _.getIdName(inst.id()) << " defined in block " + << _.getIdName(block->id()) + << " does not dominate its use in block " + << _.getIdName(use_block->id()); + } + } + } + } else { + // If the Ids defined within a function but not in a block(i.e. function + // parameters, block ids), then make sure all references to that Id + // appear within the same function + for (auto use : inst.uses()) { + const Instruction* user = use.first; + if (user->function() && user->function() != func) { + return _.diag(SPV_ERROR_INVALID_ID, _.FindDef(func->id())) + << "ID " << _.getIdName(inst.id()) << " used in function " + << _.getIdName(user->function()->id()) + << " is used outside of it's defining function " + << _.getIdName(func->id()); + } + } + } + } + // NOTE: Ids defined outside of functions must appear before they are used + // This check is being performed in the IdPass function + } + + // Check all OpPhi parent blocks are dominated by the variable's defining + // blocks + for (const Instruction* phi : phi_instructions) { + if (phi->block()->reachable() == false) continue; + for (size_t i = 3; i < phi->operands().size(); i += 2) { + const Instruction* variable = _.FindDef(phi->word(i)); + const BasicBlock* parent = + phi->function()->GetBlock(phi->word(i + 1)).first; + if (variable->block() && parent->reachable() && + !variable->block()->dominates(*parent)) { + return _.diag(SPV_ERROR_INVALID_ID, phi) + << "In OpPhi instruction " << _.getIdName(phi->id()) << ", ID " + << _.getIdName(variable->id()) + << " definition does not dominate its parent " + << _.getIdName(parent->id()); + } + } + } + + return SPV_SUCCESS; +} + +// Performs SSA validation on the IDs of an instruction. The +// can_have_forward_declared_ids functor should return true if the +// instruction operand's ID can be forward referenced. +spv_result_t IdPass(ValidationState_t& _, Instruction* inst) { + auto can_have_forward_declared_ids = + spvOperandCanBeForwardDeclaredFunction(inst->opcode()); + + // Keep track of a result id defined by this instruction. 0 means it + // does not define an id. + uint32_t result_id = 0; + + for (unsigned i = 0; i < inst->operands().size(); i++) { + const spv_parsed_operand_t& operand = inst->operand(i); + const spv_operand_type_t& type = operand.type; + // We only care about Id operands, which are a single word. + const uint32_t operand_word = inst->word(operand.offset); + + auto ret = SPV_ERROR_INTERNAL; + switch (type) { + case SPV_OPERAND_TYPE_RESULT_ID: + // NOTE: Multiple Id definitions are being checked by the binary parser. + // + // Defer undefined-forward-reference removal until after we've analyzed + // the remaining operands to this instruction. Deferral only matters + // for OpPhi since it's the only case where it defines its own forward + // reference. Other instructions that can have forward references + // either don't define a value or the forward reference is to a function + // Id (and hence defined outside of a function body). + result_id = operand_word; + // NOTE: The result Id is added (in RegisterInstruction) *after* all of + // the other Ids have been checked to avoid premature use in the same + // instruction. + ret = SPV_SUCCESS; + break; + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + if (_.IsDefinedId(operand_word)) { + ret = SPV_SUCCESS; + } else if (can_have_forward_declared_ids(i)) { + ret = _.ForwardDeclareId(operand_word); + } else { + ret = _.diag(SPV_ERROR_INVALID_ID, inst) + << "ID " << _.getIdName(operand_word) + << " has not been defined"; + } + break; + case SPV_OPERAND_TYPE_TYPE_ID: + if (_.IsDefinedId(operand_word)) { + auto* def = _.FindDef(operand_word); + if (!spvOpcodeGeneratesType(def->opcode())) { + ret = _.diag(SPV_ERROR_INVALID_ID, inst) + << "ID " << _.getIdName(operand_word) << " is not a type id"; + } else { + ret = SPV_SUCCESS; + } + } else { + ret = _.diag(SPV_ERROR_INVALID_ID, inst) + << "ID " << _.getIdName(operand_word) + << " has not been defined"; + } + break; + default: + ret = SPV_SUCCESS; + break; + } + if (SPV_SUCCESS != ret) return ret; + } + if (result_id) _.RemoveIfForwardDeclared(result_id); + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_image.cpp b/3rdparty/spirv-tools/source/val/validate_image.cpp new file mode 100644 index 000000000..2c020ed1b --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_image.cpp @@ -0,0 +1,1640 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of image instructions. + +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Performs compile time check that all SpvImageOperandsXXX cases are handled in +// this module. If SpvImageOperandsXXX list changes, this function will fail the +// build. +// For all other purposes this is a dummy function. +bool CheckAllImageOperandsHandled() { + SpvImageOperandsMask enum_val = SpvImageOperandsBiasMask; + + // Some improvised code to prevent the compiler from considering enum_val + // constant and optimizing the switch away. + uint32_t stack_var = 0; + if (reinterpret_cast(&stack_var) % 256) + enum_val = SpvImageOperandsLodMask; + + switch (enum_val) { + // Please update the validation rules in this module if you are changing + // the list of image operands, and add new enum values to this switch. + case SpvImageOperandsMaskNone: + return false; + case SpvImageOperandsBiasMask: + case SpvImageOperandsLodMask: + case SpvImageOperandsGradMask: + case SpvImageOperandsConstOffsetMask: + case SpvImageOperandsOffsetMask: + case SpvImageOperandsConstOffsetsMask: + case SpvImageOperandsSampleMask: + case SpvImageOperandsMinLodMask: + return true; + } + return false; +} + +// Used by GetImageTypeInfo. See OpTypeImage spec for more information. +struct ImageTypeInfo { + uint32_t sampled_type = 0; + SpvDim dim = SpvDimMax; + uint32_t depth = 0; + uint32_t arrayed = 0; + uint32_t multisampled = 0; + uint32_t sampled = 0; + SpvImageFormat format = SpvImageFormatMax; + SpvAccessQualifier access_qualifier = SpvAccessQualifierMax; +}; + +// Provides information on image type. |id| should be object of either +// OpTypeImage or OpTypeSampledImage type. Returns false in case of failure +// (not a valid id, failed to parse the instruction, etc). +bool GetImageTypeInfo(const ValidationState_t& _, uint32_t id, + ImageTypeInfo* info) { + if (!id || !info) return false; + + const Instruction* inst = _.FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeSampledImage) { + inst = _.FindDef(inst->word(2)); + assert(inst); + } + + if (inst->opcode() != SpvOpTypeImage) return false; + + const size_t num_words = inst->words().size(); + if (num_words != 9 && num_words != 10) return false; + + info->sampled_type = inst->word(2); + info->dim = static_cast(inst->word(3)); + info->depth = inst->word(4); + info->arrayed = inst->word(5); + info->multisampled = inst->word(6); + info->sampled = inst->word(7); + info->format = static_cast(inst->word(8)); + info->access_qualifier = num_words < 10 + ? SpvAccessQualifierMax + : static_cast(inst->word(9)); + return true; +} + +bool IsImplicitLod(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + return true; + default: + break; + } + return false; +} + +bool IsExplicitLod(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + return true; + default: + break; + } + return false; +} + +// Returns true if the opcode is a Image instruction which applies +// homogenous projection to the coordinates. +bool IsProj(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + return true; + default: + break; + } + return false; +} + +// Returns the number of components in a coordinate used to access a texel in +// a single plane of an image with the given parameters. +uint32_t GetPlaneCoordSize(const ImageTypeInfo& info) { + uint32_t plane_size = 0; + // If this switch breaks your build, please add new values below. + switch (info.dim) { + case SpvDim1D: + case SpvDimBuffer: + plane_size = 1; + break; + case SpvDim2D: + case SpvDimRect: + case SpvDimSubpassData: + plane_size = 2; + break; + case SpvDim3D: + case SpvDimCube: + // For Cube direction vector is used instead of UV. + plane_size = 3; + break; + case SpvDimMax: + assert(0); + break; + } + + return plane_size; +} + +// Returns minimal number of coordinates based on image dim, arrayed and whether +// the instruction uses projection coordinates. +uint32_t GetMinCoordSize(SpvOp opcode, const ImageTypeInfo& info) { + if (info.dim == SpvDimCube && + (opcode == SpvOpImageRead || opcode == SpvOpImageWrite || + opcode == SpvOpImageSparseRead)) { + // These opcodes use UV for Cube, not direction vector. + return 3; + } + + return GetPlaneCoordSize(info) + info.arrayed + (IsProj(opcode) ? 1 : 0); +} + +// Checks ImageOperand bitfield and respective operands. +spv_result_t ValidateImageOperands(ValidationState_t& _, + const Instruction* inst, + const ImageTypeInfo& info, uint32_t mask, + uint32_t word_index) { + static const bool kAllImageOperandsHandled = CheckAllImageOperandsHandled(); + (void)kAllImageOperandsHandled; + + const SpvOp opcode = inst->opcode(); + const size_t num_words = inst->words().size(); + + size_t expected_num_image_operand_words = spvtools::utils::CountSetBits(mask); + if (mask & SpvImageOperandsGradMask) { + // Grad uses two words. + ++expected_num_image_operand_words; + } + + if (expected_num_image_operand_words != num_words - word_index) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Number of image operand ids doesn't correspond to the bit mask"; + } + + if (spvtools::utils::CountSetBits( + mask & (SpvImageOperandsOffsetMask | SpvImageOperandsConstOffsetMask | + SpvImageOperandsConstOffsetsMask)) > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operands Offset, ConstOffset, ConstOffsets cannot be used " + << "together"; + } + + const bool is_implicit_lod = IsImplicitLod(opcode); + const bool is_explicit_lod = IsExplicitLod(opcode); + + // The checks should be done in the order of definition of OperandImage. + + if (mask & SpvImageOperandsBiasMask) { + if (!is_implicit_lod) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Bias can only be used with ImplicitLod opcodes"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsFloatScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Bias to be float scalar"; + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Bias requires 'Dim' parameter to be 1D, 2D, 3D " + "or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Bias requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsLodMask) { + if (!is_explicit_lod && opcode != SpvOpImageFetch && + opcode != SpvOpImageSparseFetch) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Lod can only be used with ExplicitLod opcodes " + << "and OpImageFetch"; + } + + if (mask & SpvImageOperandsGradMask) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand bits Lod and Grad cannot be set at the same " + "time"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (is_explicit_lod) { + if (!_.IsFloatScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Lod to be float scalar when used " + << "with ExplicitLod"; + } + } else { + if (!_.IsIntScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Lod to be int scalar when used with " + << "OpImageFetch"; + } + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Lod requires 'Dim' parameter to be 1D, 2D, 3D " + "or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Lod requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsGradMask) { + if (!is_explicit_lod) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Grad can only be used with ExplicitLod opcodes"; + } + + const uint32_t dx_type_id = _.GetTypeId(inst->word(word_index++)); + const uint32_t dy_type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsFloatScalarOrVectorType(dx_type_id) || + !_.IsFloatScalarOrVectorType(dy_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected both Image Operand Grad ids to be float scalars or " + << "vectors"; + } + + const uint32_t plane_size = GetPlaneCoordSize(info); + const uint32_t dx_size = _.GetDimension(dx_type_id); + const uint32_t dy_size = _.GetDimension(dy_type_id); + if (plane_size != dx_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Grad dx to have " << plane_size + << " components, but given " << dx_size; + } + + if (plane_size != dy_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Grad dy to have " << plane_size + << " components, but given " << dy_size; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Grad requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsConstOffsetMask) { + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand ConstOffset cannot be used with Cube Image " + "'Dim'"; + } + + const uint32_t id = inst->word(word_index++); + const uint32_t type_id = _.GetTypeId(id); + if (!_.IsIntScalarOrVectorType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffset to be int scalar or " + << "vector"; + } + + if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffset to be a const object"; + } + + const uint32_t plane_size = GetPlaneCoordSize(info); + const uint32_t offset_size = _.GetDimension(type_id); + if (plane_size != offset_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffset to have " << plane_size + << " components, but given " << offset_size; + } + } + + if (mask & SpvImageOperandsOffsetMask) { + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Offset cannot be used with Cube Image 'Dim'"; + } + + const uint32_t id = inst->word(word_index++); + const uint32_t type_id = _.GetTypeId(id); + if (!_.IsIntScalarOrVectorType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Offset to be int scalar or " + << "vector"; + } + + const uint32_t plane_size = GetPlaneCoordSize(info); + const uint32_t offset_size = _.GetDimension(type_id); + if (plane_size != offset_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Offset to have " << plane_size + << " components, but given " << offset_size; + } + } + + if (mask & SpvImageOperandsConstOffsetsMask) { + if (opcode != SpvOpImageGather && opcode != SpvOpImageDrefGather && + opcode != SpvOpImageSparseGather && + opcode != SpvOpImageSparseDrefGather) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand ConstOffsets can only be used with " + "OpImageGather and OpImageDrefGather"; + } + + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand ConstOffsets cannot be used with Cube Image " + "'Dim'"; + } + + const uint32_t id = inst->word(word_index++); + const uint32_t type_id = _.GetTypeId(id); + const Instruction* type_inst = _.FindDef(type_id); + assert(type_inst); + + if (type_inst->opcode() != SpvOpTypeArray) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets to be an array of size 4"; + } + + uint64_t array_size = 0; + if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { + assert(0 && "Array type definition is corrupt"); + } + + if (array_size != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets to be an array of size 4"; + } + + const uint32_t component_type = type_inst->word(2); + if (!_.IsIntVectorType(component_type) || + _.GetDimension(component_type) != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets array componenets to be " + "int vectors of size 2"; + } + + if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets to be a const object"; + } + } + + if (mask & SpvImageOperandsSampleMask) { + if (opcode != SpvOpImageFetch && opcode != SpvOpImageRead && + opcode != SpvOpImageWrite && opcode != SpvOpImageSparseFetch && + opcode != SpvOpImageSparseRead) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Sample can only be used with OpImageFetch, " + << "OpImageRead, OpImageWrite, OpImageSparseFetch and " + << "OpImageSparseRead"; + } + + if (info.multisampled == 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Sample requires non-zero 'MS' parameter"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsIntScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Sample to be int scalar"; + } + } + + if (mask & SpvImageOperandsMinLodMask) { + if (!is_implicit_lod && !(mask & SpvImageOperandsGradMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MinLod can only be used with ImplicitLod " + << "opcodes or together with Image Operand Grad"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsFloatScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand MinLod to be float scalar"; + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MinLod requires 'Dim' parameter to be 1D, 2D, " + "3D or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MinLod requires 'MS' parameter to be 0"; + } + } + + return SPV_SUCCESS; +} + +// Checks some of the validation rules which are common to multiple opcodes. +spv_result_t ValidateImageCommon(ValidationState_t& _, const Instruction* inst, + const ImageTypeInfo& info) { + const SpvOp opcode = inst->opcode(); + if (IsProj(opcode)) { + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimRect) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' parameter to be 1D, 2D, 3D or Rect"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Image 'MS' parameter to be 0"; + } + + if (info.arrayed != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Image 'arrayed' parameter to be 0"; + } + } + + if (opcode == SpvOpImageRead || opcode == SpvOpImageSparseRead || + opcode == SpvOpImageWrite) { + if (info.sampled == 0) { + } else if (info.sampled == 2) { + if (info.dim == SpvDim1D && !_.HasCapability(SpvCapabilityImage1D)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability Image1D is required to access storage image"; + } else if (info.dim == SpvDimRect && + !_.HasCapability(SpvCapabilityImageRect)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageRect is required to access storage image"; + } else if (info.dim == SpvDimBuffer && + !_.HasCapability(SpvCapabilityImageBuffer)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageBuffer is required to access storage image"; + } else if (info.dim == SpvDimCube && info.arrayed == 1 && + !_.HasCapability(SpvCapabilityImageCubeArray)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageCubeArray is required to access " + << "storage image"; + } + + if (info.multisampled == 1 && + !_.HasCapability(SpvCapabilityImageMSArray)) { +#if 0 + // TODO(atgoo@github.com) The description of this rule in the spec + // is unclear and Glslang doesn't declare ImageMSArray. Need to clarify + // and reenable. + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageMSArray is required to access storage " + << "image"; +#endif + } + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 0 or 2"; + } + } + + return SPV_SUCCESS; +} + +// Returns true if opcode is *ImageSparse*, false otherwise. +bool IsSparse(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + case SpvOpImageSparseFetch: + case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + case SpvOpImageSparseTexelsResident: + case SpvOpImageSparseRead: { + return true; + } + + default: { return false; } + } + + return false; +} + +// Checks sparse image opcode result type and returns the second struct member. +// Returns inst.type_id for non-sparse image opcodes. +// Not valid for sparse image opcodes which do not return a struct. +spv_result_t GetActualResultType(ValidationState_t& _, const Instruction* inst, + uint32_t* actual_result_type) { + const SpvOp opcode = inst->opcode(); + + if (IsSparse(opcode)) { + const Instruction* const type_inst = _.FindDef(inst->type_id()); + assert(type_inst); + + if (!type_inst || type_inst->opcode() != SpvOpTypeStruct) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeStruct"; + } + + if (type_inst->words().size() != 4 || + !_.IsIntScalarType(type_inst->word(2))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a struct containing an int " + "scalar and a texel"; + } + + *actual_result_type = type_inst->word(3); + } else { + *actual_result_type = inst->type_id(); + } + + return SPV_SUCCESS; +} + +// Returns a string describing actual result type of an opcode. +// Not valid for sparse image opcodes which do not return a struct. +const char* GetActualResultTypeStr(SpvOp opcode) { + if (IsSparse(opcode)) return "Result Type's second member"; + return "Result Type"; +} + +spv_result_t ValidateTypeImage(ValidationState_t& _, const Instruction* inst) { + assert(inst->type_id() == 0); + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, inst->word(1), &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + if ((!_.IsFloatScalarType(info.sampled_type) && + !_.IsIntScalarType(info.sampled_type)) || + 32 != _.GetBitWidth(info.sampled_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Type to be a 32-bit int or float " + "scalar type for Vulkan environment"; + } + } else { + const SpvOp sampled_type_opcode = _.GetIdOpcode(info.sampled_type); + if (sampled_type_opcode != SpvOpTypeVoid && + sampled_type_opcode != SpvOpTypeInt && + sampled_type_opcode != SpvOpTypeFloat) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Type to be either void or" + << " numerical scalar type"; + } + } + + // Dim is checked elsewhere. + + if (info.depth > 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid Depth " << info.depth << " (must be 0, 1 or 2)"; + } + + if (info.arrayed > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid Arrayed " << info.arrayed << " (must be 0 or 1)"; + } + + if (info.multisampled > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid MS " << info.multisampled << " (must be 0 or 1)"; + } + + if (info.sampled > 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid Sampled " << info.sampled << " (must be 0, 1 or 2)"; + } + + if (info.dim == SpvDimSubpassData) { + if (info.sampled != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Dim SubpassData requires Sampled to be 2"; + } + + if (info.format != SpvImageFormatUnknown) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Dim SubpassData requires format Unknown"; + } + } + + // Format and Access Qualifier are checked elsewhere. + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeSampledImage(ValidationState_t& _, + const Instruction* inst) { + const uint32_t image_type = inst->word(2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateSampledImage(ValidationState_t& _, + const Instruction* inst) { + if (_.GetIdOpcode(inst->type_id()) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeSampledImage."; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage."; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + // TODO(atgoo@github.com) Check compatibility of result type and received + // image. + + if (spvIsVulkanEnv(_.context()->target_env)) { + if (info.sampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 1 " + << "for Vulkan environment."; + } + } else { + if (info.sampled != 0 && info.sampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 0 or 1"; + } + } + + if (info.dim == SpvDimSubpassData) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' parameter to be not SubpassData."; + } + + if (_.GetIdOpcode(_.GetOperandTypeId(inst, 3)) != SpvOpTypeSampler) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampler to be of type OpTypeSampler"; + } + + // We need to validate 2 things: + // * All OpSampledImage instructions must be in the same block in which their + // Result are consumed. + // * Result from OpSampledImage instructions must not appear as operands + // to OpPhi instructions or OpSelect instructions, or any instructions other + // than the image lookup and query instructions specified to take an operand + // whose type is OpTypeSampledImage. + std::vector consumers = _.getSampledImageConsumers(inst->id()); + if (!consumers.empty()) { + for (auto consumer_id : consumers) { + const auto consumer_instr = _.FindDef(consumer_id); + const auto consumer_opcode = consumer_instr->opcode(); + if (consumer_instr->block() != inst->block()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "All OpSampledImage instructions must be in the same block " + "in " + "which their Result are consumed. OpSampledImage Result " + "Type '" + << _.getIdName(inst->id()) + << "' has a consumer in a different basic " + "block. The consumer instruction is '" + << _.getIdName(consumer_id) << "'."; + } + // TODO: The following check is incomplete. We should also check that the + // Sampled Image is not used by instructions that should not take + // SampledImage as an argument. We could find the list of valid + // instructions by scanning for "Sampled Image" in the operand description + // field in the grammar file. + if (consumer_opcode == SpvOpPhi || consumer_opcode == SpvOpSelect) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Result from OpSampledImage instruction must not appear " + "as " + "operands of Op" + << spvOpcodeString(static_cast(consumer_opcode)) << "." + << " Found result '" << _.getIdName(inst->id()) + << "' as an operand of '" << _.getIdName(consumer_id) + << "'."; + } + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageLod(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + if (!_.IsIntVectorType(actual_result_type) && + !_.IsFloatVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float vector type"; + } + + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Image to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t texel_component_type = + _.GetComponentType(actual_result_type); + if (texel_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if ((opcode == SpvOpImageSampleExplicitLod || + opcode == SpvOpImageSparseSampleExplicitLod) && + _.HasCapability(SpvCapabilityKernel)) { + if (!_.IsFloatScalarOrVectorType(coord_type) && + !_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int or float scalar or vector"; + } + } else { + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (inst->words().size() <= 5) { + assert(IsImplicitLod(opcode)); + return SPV_SUCCESS; + } + + const uint32_t mask = inst->word(5); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 6)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageDrefLod(ValidationState_t& _, + const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + if (!_.IsIntScalarType(actual_result_type) && + !_.IsFloatScalarType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float scalar type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Image to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + if (actual_result_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode); + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + const uint32_t dref_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Dref to be of 32-bit float type"; + } + + if (inst->words().size() <= 6) { + assert(IsImplicitLod(opcode)); + return SPV_SUCCESS; + } + + const uint32_t mask = inst->word(6); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 7)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageFetch(ValidationState_t& _, const Instruction* inst) { + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + const SpvOp opcode = inst->opcode(); + if (!_.IsIntVectorType(actual_result_type) && + !_.IsFloatVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float vector type"; + } + + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t result_component_type = + _.GetComponentType(actual_result_type); + if (result_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'Dim' cannot be Cube"; + } + + if (info.sampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 1"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (inst->words().size() <= 5) return SPV_SUCCESS; + + const uint32_t mask = inst->word(5); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 6)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageGather(ValidationState_t& _, + const Instruction* inst) { + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) + return error; + + const SpvOp opcode = inst->opcode(); + if (!_.IsIntVectorType(actual_result_type) && + !_.IsFloatVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float vector type"; + } + + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Image to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (opcode == SpvOpImageDrefGather || opcode == SpvOpImageSparseDrefGather || + _.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t result_component_type = + _.GetComponentType(actual_result_type); + if (result_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + if (info.dim != SpvDim2D && info.dim != SpvDimCube && + info.dim != SpvDimRect) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' cannot be Cube"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (opcode == SpvOpImageGather || opcode == SpvOpImageSparseGather) { + const uint32_t component_index_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarType(component_index_type) || + _.GetBitWidth(component_index_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Component to be 32-bit int scalar"; + } + } else { + assert(opcode == SpvOpImageDrefGather || + opcode == SpvOpImageSparseDrefGather); + const uint32_t dref_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Dref to be of 32-bit float type"; + } + } + + if (inst->words().size() <= 6) return SPV_SUCCESS; + + const uint32_t mask = inst->word(6); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 7)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageRead(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + if (!_.IsIntScalarOrVectorType(actual_result_type) && + !_.IsFloatScalarOrVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float scalar or vector type"; + } + +#if 0 + // TODO(atgoo@github.com) Disabled until the spec is clarified. + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } +#endif + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.dim == SpvDimSubpassData) { + if (opcode == SpvOpImageSparseRead) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Dim SubpassData cannot be used with ImageSparseRead"; + } + + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + std::string("Dim SubpassData requires Fragment execution model: ") + + spvOpcodeString(opcode)); + } + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t result_component_type = + _.GetComponentType(actual_result_type); + if (result_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (info.format == SpvImageFormatUnknown && info.dim != SpvDimSubpassData && + !_.HasCapability(SpvCapabilityStorageImageReadWithoutFormat)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability StorageImageReadWithoutFormat is required to " + << "read storage image"; + } + + if (inst->words().size() <= 5) return SPV_SUCCESS; + + const uint32_t mask = inst->word(5); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 6)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageWrite(ValidationState_t& _, const Instruction* inst) { + const uint32_t image_type = _.GetOperandTypeId(inst, 0); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.dim == SpvDimSubpassData) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' cannot be SubpassData"; + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + const uint32_t coord_type = _.GetOperandTypeId(inst, 1); + if (!_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(inst->opcode(), info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + // TODO(atgoo@github.com) The spec doesn't explicitely say what the type + // of texel should be. + const uint32_t texel_type = _.GetOperandTypeId(inst, 2); + if (!_.IsIntScalarOrVectorType(texel_type) && + !_.IsFloatScalarOrVectorType(texel_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Texel to be int or float vector or scalar"; + } + +#if 0 + // TODO: See above. + if (_.GetDimension(texel_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Texel to have 4 components"; + } +#endif + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t texel_component_type = _.GetComponentType(texel_type); + if (texel_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as Texel " + << "components"; + } + } + + if (info.format == SpvImageFormatUnknown && info.dim != SpvDimSubpassData && + !_.HasCapability(SpvCapabilityStorageImageWriteWithoutFormat)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability StorageImageWriteWithoutFormat is required to " + "write " + << "to storage image"; + } + + if (inst->words().size() <= 4) return SPV_SUCCESS; + + const uint32_t mask = inst->word(4); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 5)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImage(ValidationState_t& _, const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + if (_.GetIdOpcode(result_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeImage"; + } + + const uint32_t sampled_image_type = _.GetOperandTypeId(inst, 2); + const Instruction* sampled_image_type_inst = _.FindDef(sampled_image_type); + assert(sampled_image_type_inst); + + if (sampled_image_type_inst->opcode() != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sample Image to be of type OpTypeSampleImage"; + } + + if (sampled_image_type_inst->word(2) != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sample Image image type to be equal to Result Type"; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQuerySizeLod(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar or vector type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + uint32_t expected_num_components = info.arrayed; + switch (info.dim) { + case SpvDim1D: + expected_num_components += 1; + break; + case SpvDim2D: + case SpvDimCube: + expected_num_components += 2; + break; + case SpvDim3D: + expected_num_components += 3; + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, 2D, 3D or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'MS' must be 0"; + } + + uint32_t result_num_components = _.GetDimension(result_type); + if (result_num_components != expected_num_components) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type has " << result_num_components << " components, " + << "but " << expected_num_components << " expected"; + } + + const uint32_t lod_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarType(lod_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Level of Detail to be int scalar"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQuerySize(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar or vector type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + +#if 0 + // TODO(atgoo@github.com) The spec doesn't whitelist all Dims supported by + // GLSL. Need to verify if there is an error and reenable. + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + uint32_t expected_num_components = info.arrayed; + switch (info.dim) { + case SpvDimBuffer: + expected_num_components += 1; + break; + case SpvDim2D: + if (info.multisampled != 1 && info.sampled != 0 && + info.sampled != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected either 'MS'=1 or 'Sampled'=0 or 'Sampled'=2 " + << "for 2D dim"; + } + case SpvDimRect: + expected_num_components += 2; + break; + case SpvDim3D: + expected_num_components += 3; + if (info.sampled != 0 && + info.sampled != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected either 'Sampled'=0 or 'Sampled'=2 " + << "for 3D dim"; + } + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be Buffer, 2D, 3D or Rect"; + } + + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'MS' must be 0"; + } + + uint32_t result_num_components = _.GetDimension(result_type); + if (result_num_components != expected_num_components) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type has " << result_num_components << " components, " + << "but " << expected_num_components << " expected"; + } +#endif + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQueryFormatOrOrder(ValidationState_t& _, + const Instruction* inst) { + if (!_.IsIntScalarType(inst->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar type"; + } + + if (_.GetIdOpcode(_.GetOperandTypeId(inst, 2)) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operand to be of type OpTypeImage"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQueryLod(ValidationState_t& _, + const Instruction* inst) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpImageQueryLod requires Fragment execution model"); + + const uint32_t result_type = inst->type_id(); + if (!_.IsFloatVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be float vector type"; + } + + if (_.GetDimension(result_type) != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to have 2 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image operand to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, 2D, 3D or Cube"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (_.HasCapability(SpvCapabilityKernel)) { + if (!_.IsFloatScalarOrVectorType(coord_type) && + !_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int or float scalar or vector"; + } + } else { + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + } + + const uint32_t min_coord_size = GetPlaneCoordSize(info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageSparseLod(ValidationState_t& _, + const Instruction* inst) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Instruction reserved for future use, use of this instruction " + << "is invalid"; +} + +spv_result_t ValidateImageQueryLevelsOrSamples(ValidationState_t& _, + const Instruction* inst) { + if (!_.IsIntScalarType(inst->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + const SpvOp opcode = inst->opcode(); + if (opcode == SpvOpImageQueryLevels) { + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, 2D, 3D or Cube"; + } + } else { + assert(opcode == SpvOpImageQuerySamples); + if (info.dim != SpvDim2D) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'Dim' must be 2D"; + } + + if (info.multisampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'MS' must be 1"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageSparseTexelsResident(ValidationState_t& _, + const Instruction* inst) { + if (!_.IsBoolScalarType(inst->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be bool scalar type"; + } + + const uint32_t resident_code_type = _.GetOperandTypeId(inst, 2); + if (!_.IsIntScalarType(resident_code_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Resident Code to be int scalar"; + } + + return SPV_SUCCESS; +} + +} // namespace + +// Validates correctness of image instructions. +spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + if (IsImplicitLod(opcode)) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "ImplicitLod instructions require Fragment execution model"); + } + + switch (opcode) { + case SpvOpTypeImage: + return ValidateTypeImage(_, inst); + case SpvOpTypeSampledImage: + return ValidateTypeSampledImage(_, inst); + case SpvOpSampledImage: + return ValidateSampledImage(_, inst); + + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + return ValidateImageLod(_, inst); + + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + return ValidateImageDrefLod(_, inst); + + case SpvOpImageFetch: + case SpvOpImageSparseFetch: + return ValidateImageFetch(_, inst); + + case SpvOpImageGather: + case SpvOpImageDrefGather: + case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + return ValidateImageGather(_, inst); + + case SpvOpImageRead: + case SpvOpImageSparseRead: + return ValidateImageRead(_, inst); + + case SpvOpImageWrite: + return ValidateImageWrite(_, inst); + + case SpvOpImage: + return ValidateImage(_, inst); + + case SpvOpImageQueryFormat: + case SpvOpImageQueryOrder: + return ValidateImageQueryFormatOrOrder(_, inst); + + case SpvOpImageQuerySizeLod: + return ValidateImageQuerySizeLod(_, inst); + case SpvOpImageQuerySize: + return ValidateImageQuerySize(_, inst); + case SpvOpImageQueryLod: + return ValidateImageQueryLod(_, inst); + + case SpvOpImageQueryLevels: + case SpvOpImageQuerySamples: + return ValidateImageQueryLevelsOrSamples(_, inst); + + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + return ValidateImageSparseLod(_, inst); + + case SpvOpImageSparseTexelsResident: + return ValidateImageSparseTexelsResident(_, inst); + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_instruction.cpp b/3rdparty/spirv-tools/source/val/validate_instruction.cpp similarity index 59% rename from 3rdparty/spirv-tools/source/validate_instruction.cpp rename to 3rdparty/spirv-tools/source/val/validate_instruction.cpp index a50e0ea89..85995caa3 100644 --- a/3rdparty/spirv-tools/source/validate_instruction.cpp +++ b/3rdparty/spirv-tools/source/val/validate_instruction.cpp @@ -14,35 +14,31 @@ // Performs validation on instructions that appear inside of a SPIR-V block. -#include "validate.h" +#include "source/val/validate.h" #include #include - #include #include +#include -#include "binary.h" -#include "diagnostic.h" -#include "enum_set.h" -#include "enum_string_mapping.h" -#include "extensions.h" -#include "opcode.h" -#include "operand.h" -#include "spirv_constant.h" -#include "spirv_definition.h" -#include "spirv_target_env.h" -#include "spirv_validator_options.h" -#include "util/string_utils.h" -#include "val/function.h" -#include "val/validation_state.h" - -using libspirv::AssemblyGrammar; -using libspirv::CapabilitySet; -using libspirv::DiagnosticStream; -using libspirv::ExtensionSet; -using libspirv::ValidationState_t; +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/enum_set.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_definition.h" +#include "source/spirv_target_env.h" +#include "source/spirv_validator_options.h" +#include "source/util/string_utils.h" +#include "source/val/function.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { std::string ToString(const CapabilitySet& capabilities, @@ -59,16 +55,6 @@ std::string ToString(const CapabilitySet& capabilities, return ss.str(); } -// Reports a missing-capability error to _'s diagnostic stream and returns -// SPV_ERROR_INVALID_CAPABILITY. -spv_result_t CapabilityError(ValidationState_t& _, int which_operand, - SpvOp opcode, - const std::string& required_capabilities) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) - << "Operand " << which_operand << " of " << spvOpcodeString(opcode) - << " requires one of these capabilities: " << required_capabilities; -} - // Returns capabilities that enable an opcode. An empty result is interpreted // as no prohibition of use of the opcode. If the result is non-empty, then // the opcode may only be used if at least one of the capabilities is specified @@ -86,8 +72,7 @@ CapabilitySet EnablingCapabilitiesForOp(const ValidationState_t& state, case SpvOpGroupFMaxNonUniformAMD: case SpvOpGroupUMaxNonUniformAMD: case SpvOpGroupSMaxNonUniformAMD: - if (state.HasExtension(libspirv::kSPV_AMD_shader_ballot)) - return CapabilitySet(); + if (state.HasExtension(kSPV_AMD_shader_ballot)) return CapabilitySet(); break; default: break; @@ -101,9 +86,14 @@ CapabilitySet EnablingCapabilitiesForOp(const ValidationState_t& state, return CapabilitySet(); } -// Returns an operand's required capabilities. -CapabilitySet RequiredCapabilities(const ValidationState_t& state, - spv_operand_type_t type, uint32_t operand) { +// Returns SPV_SUCCESS if the given operand is enabled by capabilities declared +// in the module. Otherwise issues an error message and returns +// SPV_ERROR_INVALID_CAPABILITY. +spv_result_t CheckRequiredCapabilities(const ValidationState_t& state, + const Instruction* inst, + size_t which_operand, + spv_operand_type_t type, + uint32_t operand) { // Mere mention of PointSize, ClipDistance, or CullDistance in a Builtin // decoration does not require the associated capability. The use of such // a variable value should trigger the capability requirement, but that's @@ -114,47 +104,54 @@ CapabilitySet RequiredCapabilities(const ValidationState_t& state, case SpvBuiltInPointSize: case SpvBuiltInClipDistance: case SpvBuiltInCullDistance: - return CapabilitySet(); + return SPV_SUCCESS; default: break; } } else if (type == SPV_OPERAND_TYPE_FP_ROUNDING_MODE) { // Allow all FP rounding modes if requested if (state.features().free_fp_rounding_mode) { - return CapabilitySet(); + return SPV_SUCCESS; } + } else if (type == SPV_OPERAND_TYPE_GROUP_OPERATION && + state.features().group_ops_reduce_and_scans && + (operand <= uint32_t(SpvGroupOperationExclusiveScan))) { + // Allow certain group operations if requested. + return SPV_SUCCESS; } - spv_operand_desc operand_desc; - const auto ret = state.grammar().lookupOperand(type, operand, &operand_desc); - if (ret == SPV_SUCCESS) { + CapabilitySet enabling_capabilities; + spv_operand_desc operand_desc = nullptr; + const auto lookup_result = + state.grammar().lookupOperand(type, operand, &operand_desc); + if (lookup_result == SPV_SUCCESS) { // Allow FPRoundingMode decoration if requested. if (type == SPV_OPERAND_TYPE_DECORATION && operand_desc->value == SpvDecorationFPRoundingMode) { - if (state.features().free_fp_rounding_mode) return CapabilitySet(); + if (state.features().free_fp_rounding_mode) return SPV_SUCCESS; // Vulkan API requires more capabilities on rounding mode. if (spvIsVulkanEnv(state.context()->target_env)) { - CapabilitySet cap_set; - cap_set.Add(SpvCapabilityStorageUniformBufferBlock16); - cap_set.Add(SpvCapabilityStorageUniform16); - cap_set.Add(SpvCapabilityStoragePushConstant16); - cap_set.Add(SpvCapabilityStorageInputOutput16); - return cap_set; + enabling_capabilities.Add(SpvCapabilityStorageUniformBufferBlock16); + enabling_capabilities.Add(SpvCapabilityStorageUniform16); + enabling_capabilities.Add(SpvCapabilityStoragePushConstant16); + enabling_capabilities.Add(SpvCapabilityStorageInputOutput16); } - } - // Allow certain group operations if requested. - if (state.features().group_ops_reduce_and_scans && - type == SPV_OPERAND_TYPE_GROUP_OPERATION && - (operand <= uint32_t(SpvGroupOperationExclusiveScan))) { - return CapabilitySet(); + } else { + enabling_capabilities = state.grammar().filterCapsAgainstTargetEnv( + operand_desc->capabilities, operand_desc->numCapabilities); } - return state.grammar().filterCapsAgainstTargetEnv( - operand_desc->capabilities, operand_desc->numCapabilities); + if (!state.HasAnyOfCapabilities(enabling_capabilities)) { + return state.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Operand " << which_operand << " of " + << spvOpcodeString(inst->opcode()) + << " requires one of these capabilities: " + << ToString(enabling_capabilities, state.grammar()); + } } - return CapabilitySet(); + return SPV_SUCCESS; } // Returns operand's required extensions. @@ -175,32 +172,68 @@ ExtensionSet RequiredExtensions(const ValidationState_t& state, return {}; } -} // namespace +// Returns SPV_ERROR_INVALID_BINARY and emits a diagnostic if the instruction +// is explicitly reserved in the SPIR-V core spec. Otherwise return +// SPV_SUCCESS. +spv_result_t ReservedCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + switch (opcode) { + // These instructions are enabled by a capability, but should never + // be used anyway. + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: { + spv_opcode_desc inst_desc; + _.grammar().lookupOpcode(opcode, &inst_desc); + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Invalid Opcode name 'Op" << inst_desc->name << "'"; + } + default: + break; + } + return SPV_SUCCESS; +} -namespace libspirv { +// Returns SPV_ERROR_INVALID_BINARY and emits a diagnostic if the instruction +// is invalid because of an execution environment constraint. +spv_result_t EnvironmentCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + switch (opcode) { + case SpvOpUndef: + if (_.features().bans_op_undef) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "OpUndef is disallowed"; + } + break; + default: + break; + } + return SPV_SUCCESS; +} -spv_result_t CapabilityCheck(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); +// Returns SPV_ERROR_INVALID_CAPABILITY and emits a diagnostic if the +// instruction is invalid because the required capability isn't declared +// in the module. +spv_result_t CapabilityCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); CapabilitySet opcode_caps = EnablingCapabilitiesForOp(_, opcode); if (!_.HasAnyOfCapabilities(opcode_caps)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY) + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) << "Opcode " << spvOpcodeString(opcode) << " requires one of these capabilities: " << ToString(opcode_caps, _.grammar()); } - for (int i = 0; i < inst->num_operands; ++i) { - const auto& operand = inst->operands[i]; - const auto word = inst->words[operand.offset]; + for (size_t i = 0; i < inst->operands().size(); ++i) { + const auto& operand = inst->operand(i); + const auto word = inst->word(operand.offset); if (spvOperandIsConcreteMask(operand.type)) { // Check for required capabilities for each bit position of the mask. for (uint32_t mask_bit = 0x80000000; mask_bit; mask_bit >>= 1) { if (word & mask_bit) { - const auto caps = RequiredCapabilities(_, operand.type, mask_bit); - if (!_.HasAnyOfCapabilities(caps)) { - return CapabilityError(_, i + 1, opcode, - ToString(caps, _.grammar())); - } + spv_result_t status = + CheckRequiredCapabilities(_, inst, i + 1, operand.type, mask_bit); + if (status != SPV_SUCCESS) return status; } } } else if (spvIsIdType(operand.type)) { @@ -209,10 +242,9 @@ spv_result_t CapabilityCheck(ValidationState_t& _, // https://github.com/KhronosGroup/SPIRV-Tools/issues/248 } else { // Check the operand word as a whole. - const auto caps = RequiredCapabilities(_, operand.type, word); - if (!_.HasAnyOfCapabilities(caps)) { - return CapabilityError(_, i + 1, opcode, ToString(caps, _.grammar())); - } + spv_result_t status = + CheckRequiredCapabilities(_, inst, i + 1, operand.type, word); + if (status != SPV_SUCCESS) return status; } } return SPV_SUCCESS; @@ -220,65 +252,70 @@ spv_result_t CapabilityCheck(ValidationState_t& _, // Checks that all extensions required by the given instruction's operands were // declared in the module. -spv_result_t ExtensionCheck(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - for (size_t operand_index = 0; operand_index < inst->num_operands; +spv_result_t ExtensionCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + for (size_t operand_index = 0; operand_index < inst->operands().size(); ++operand_index) { - const auto& operand = inst->operands[operand_index]; - const uint32_t word = inst->words[operand.offset]; + const auto& operand = inst->operand(operand_index); + const uint32_t word = inst->word(operand.offset); const ExtensionSet required_extensions = RequiredExtensions(_, operand.type, word); if (!_.HasAnyOfExtensions(required_extensions)) { - return _.diag(SPV_ERROR_MISSING_EXTENSION) - << spvutils::CardinalToOrdinal(operand_index + 1) << " operand of " - << spvOpcodeString(opcode) << ": operand " << word - << " requires one of these extensions: " + return _.diag(SPV_ERROR_MISSING_EXTENSION, inst) + << spvtools::utils::CardinalToOrdinal(operand_index + 1) + << " operand of " << spvOpcodeString(opcode) << ": operand " + << word << " requires one of these extensions: " << ExtensionSetToString(required_extensions); } } return SPV_SUCCESS; } -// Checks that the instruction can be used in this target environment. -spv_result_t VersionCheck(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const auto opcode = static_cast(inst->opcode); +// Checks that the instruction can be used in this target environment's base +// version. Assumes that CapabilityCheck has checked direct capability +// dependencies for the opcode. +spv_result_t VersionCheck(ValidationState_t& _, const Instruction* inst) { + const auto opcode = inst->opcode(); spv_opcode_desc inst_desc; - const bool r = _.grammar().lookupOpcode(opcode, &inst_desc); + const spv_result_t r = _.grammar().lookupOpcode(opcode, &inst_desc); assert(r == SPV_SUCCESS); (void)r; const auto min_version = inst_desc->minVersion; + if (inst_desc->numCapabilities > 0u) { + // We already checked that the direct capability dependency has been + // satisfied. We don't need to check any further. + return SPV_SUCCESS; + } + ExtensionSet exts(inst_desc->numExtensions, inst_desc->extensions); if (exts.IsEmpty()) { // If no extensions can enable this instruction, then emit error messages // only concerning core SPIR-V versions if errors happen. if (min_version == ~0u) { - return _.diag(SPV_ERROR_WRONG_VERSION) + return _.diag(SPV_ERROR_WRONG_VERSION, inst) << spvOpcodeString(opcode) << " is reserved for future use."; } if (spvVersionForTargetEnv(_.grammar().target_env()) < min_version) { - return _.diag(SPV_ERROR_WRONG_VERSION) + return _.diag(SPV_ERROR_WRONG_VERSION, inst) << spvOpcodeString(opcode) << " requires " << spvTargetEnvDescription( static_cast(min_version)) << " at minimum."; } - } // Otherwise, we only error out when no enabling extensions are registered. - else if (!_.HasAnyOfExtensions(exts)) { + } else if (!_.HasAnyOfExtensions(exts)) { if (min_version == ~0u) { - return _.diag(SPV_ERROR_MISSING_EXTENSION) + return _.diag(SPV_ERROR_MISSING_EXTENSION, inst) << spvOpcodeString(opcode) << " requires one of the following extensions: " << ExtensionSetToString(exts); } if (static_cast(_.grammar().target_env()) < min_version) { - return _.diag(SPV_ERROR_WRONG_VERSION) + return _.diag(SPV_ERROR_WRONG_VERSION, inst) << spvOpcodeString(opcode) << " requires " << spvTargetEnvDescription( static_cast(min_version)) @@ -291,20 +328,18 @@ spv_result_t VersionCheck(ValidationState_t& _, } // Checks that the Resuld is within the valid bound. -spv_result_t LimitCheckIdBound(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - if (inst->result_id >= _.getIdBound()) { - return _.diag(SPV_ERROR_INVALID_BINARY) - << "Result '" << inst->result_id +spv_result_t LimitCheckIdBound(ValidationState_t& _, const Instruction* inst) { + if (inst->id() >= _.getIdBound()) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Result '" << inst->id() << "' must be less than the ID bound '" << _.getIdBound() << "'."; } return SPV_SUCCESS; } // Checks that the number of OpTypeStruct members is within the limit. -spv_result_t LimitCheckStruct(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - if (SpvOpTypeStruct != inst->opcode) { +spv_result_t LimitCheckStruct(ValidationState_t& _, const Instruction* inst) { + if (SpvOpTypeStruct != inst->opcode()) { return SPV_SUCCESS; } @@ -312,9 +347,9 @@ spv_result_t LimitCheckStruct(ValidationState_t& _, // One operand is the result ID. const uint16_t limit = static_cast(_.options()->universal_limits_.max_struct_members); - if (inst->num_operands - 1 > limit) { - return _.diag(SPV_ERROR_INVALID_BINARY) - << "Number of OpTypeStruct members (" << inst->num_operands - 1 + if (inst->operands().size() - 1 > limit) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Number of OpTypeStruct members (" << inst->operands().size() - 1 << ") has exceeded the limit (" << limit << ")."; } @@ -327,8 +362,8 @@ spv_result_t LimitCheckStruct(ValidationState_t& _, // Scalars are at depth 0. uint32_t max_member_depth = 0; // Struct members start at word 2 of OpTypeStruct instruction. - for (size_t word_i = 2; word_i < inst->num_words; ++word_i) { - auto member = inst->words[word_i]; + for (size_t word_i = 2; word_i < inst->words().size(); ++word_i) { + auto member = inst->word(word_i); auto memberTypeInstr = _.FindDef(member); if (memberTypeInstr && SpvOpTypeStruct == memberTypeInstr->opcode()) { max_member_depth = std::max( @@ -338,9 +373,9 @@ spv_result_t LimitCheckStruct(ValidationState_t& _, const uint32_t depth_limit = _.options()->universal_limits_.max_struct_depth; const uint32_t cur_depth = 1 + max_member_depth; - _.set_struct_nesting_depth(inst->result_id, cur_depth); + _.set_struct_nesting_depth(inst->id(), cur_depth); if (cur_depth > depth_limit) { - return _.diag(SPV_ERROR_INVALID_BINARY) + return _.diag(SPV_ERROR_INVALID_BINARY, inst) << "Structure Nesting Depth may not be larger than " << depth_limit << ". Found " << cur_depth << "."; } @@ -349,18 +384,17 @@ spv_result_t LimitCheckStruct(ValidationState_t& _, // Checks that the number of (literal, label) pairs in OpSwitch is within the // limit. -spv_result_t LimitCheckSwitch(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - if (SpvOpSwitch == inst->opcode) { +spv_result_t LimitCheckSwitch(ValidationState_t& _, const Instruction* inst) { + if (SpvOpSwitch == inst->opcode()) { // The instruction syntax is as follows: // OpSwitch literal label literal label ... // literal,label pairs come after the first 2 operands. // It is guaranteed at this point that num_operands is an even numner. - unsigned int num_pairs = (inst->num_operands - 2) / 2; + size_t num_pairs = (inst->operands().size() - 2) / 2; const unsigned int num_pairs_limit = _.options()->universal_limits_.max_switch_branches; if (num_pairs > num_pairs_limit) { - return _.diag(SPV_ERROR_INVALID_BINARY) + return _.diag(SPV_ERROR_INVALID_BINARY, inst) << "Number of (literal, label) pairs in OpSwitch (" << num_pairs << ") exceeds the limit (" << num_pairs_limit << ")."; } @@ -376,7 +410,7 @@ spv_result_t LimitCheckNumVars(ValidationState_t& _, const uint32_t var_id, const uint32_t num_local_vars_limit = _.options()->universal_limits_.max_local_variables; if (_.num_local_vars() > num_local_vars_limit) { - return _.diag(SPV_ERROR_INVALID_BINARY) + return _.diag(SPV_ERROR_INVALID_BINARY, nullptr) << "Number of local variables ('Function' Storage Class) " "exceeded the valid limit (" << num_local_vars_limit << ")."; @@ -386,7 +420,7 @@ spv_result_t LimitCheckNumVars(ValidationState_t& _, const uint32_t var_id, const uint32_t num_global_vars_limit = _.options()->universal_limits_.max_global_variables; if (_.num_global_vars() > num_global_vars_limit) { - return _.diag(SPV_ERROR_INVALID_BINARY) + return _.diag(SPV_ERROR_INVALID_BINARY, nullptr) << "Number of Global Variables (Storage Class other than " "'Function') exceeded the valid limit (" << num_global_vars_limit << ")."; @@ -398,27 +432,27 @@ spv_result_t LimitCheckNumVars(ValidationState_t& _, const uint32_t var_id, // Registers necessary decoration(s) for the appropriate IDs based on the // instruction. spv_result_t RegisterDecorations(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - switch (inst->opcode) { + const Instruction* inst) { + switch (inst->opcode()) { case SpvOpDecorate: { - const uint32_t target_id = inst->words[1]; - const SpvDecoration dec_type = static_cast(inst->words[2]); + const uint32_t target_id = inst->word(1); + const SpvDecoration dec_type = static_cast(inst->word(2)); std::vector dec_params; - if (inst->num_words > 3) { - dec_params.insert(dec_params.end(), inst->words + 3, - inst->words + inst->num_words); + if (inst->words().size() > 3) { + dec_params.insert(dec_params.end(), inst->words().begin() + 3, + inst->words().end()); } _.RegisterDecorationForId(target_id, Decoration(dec_type, dec_params)); break; } case SpvOpMemberDecorate: { - const uint32_t struct_id = inst->words[1]; - const uint32_t index = inst->words[2]; - const SpvDecoration dec_type = static_cast(inst->words[3]); + const uint32_t struct_id = inst->word(1); + const uint32_t index = inst->word(2); + const SpvDecoration dec_type = static_cast(inst->word(3)); std::vector dec_params; - if (inst->num_words > 4) { - dec_params.insert(dec_params.end(), inst->words + 4, - inst->words + inst->num_words); + if (inst->words().size() > 4) { + dec_params.insert(dec_params.end(), inst->words().begin() + 4, + inst->words().end()); } _.RegisterDecorationForId(struct_id, Decoration(dec_type, dec_params, index)); @@ -432,11 +466,11 @@ spv_result_t RegisterDecorations(ValidationState_t& _, case SpvOpGroupDecorate: { // Word 1 is the group . All subsequent words are target s that // are going to be decorated with the decorations. - const uint32_t decoration_group_id = inst->words[1]; + const uint32_t decoration_group_id = inst->word(1); std::vector& group_decorations = _.id_decorations(decoration_group_id); - for (int i = 2; i < inst->num_words; ++i) { - const uint32_t target_id = inst->words[i]; + for (size_t i = 2; i < inst->words().size(); ++i) { + const uint32_t target_id = inst->word(i); _.RegisterDecorationsForId(target_id, group_decorations.begin(), group_decorations.end()); } @@ -446,14 +480,14 @@ spv_result_t RegisterDecorations(ValidationState_t& _, // Word 1 is the Decoration Group followed by (struct,literal) // pairs. All decorations of the group should be applied to all the struct // members that are specified in the instructions. - const uint32_t decoration_group_id = inst->words[1]; + const uint32_t decoration_group_id = inst->word(1); std::vector& group_decorations = _.id_decorations(decoration_group_id); // Grammar checks ensures that the number of arguments to this instruction // is an odd number: 1 decoration group + (id,literal) pairs. - for (int i = 2; i + 1 < inst->num_words; i = i + 2) { - const uint32_t struct_id = inst->words[i]; - const uint32_t index = inst->words[i + 1]; + for (size_t i = 2; i + 1 < inst->words().size(); i = i + 2) { + const uint32_t struct_id = inst->word(i); + const uint32_t index = inst->word(i + 1); // ID validation phase ensures this is in fact a struct instruction and // that the index is not out of bound. _.RegisterDecorationsForStructMember(struct_id, index, @@ -469,57 +503,59 @@ spv_result_t RegisterDecorations(ValidationState_t& _, } // Parses OpExtension instruction and logs warnings if unsuccessful. -void CheckIfKnownExtension(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const std::string extension_str = GetExtensionString(inst); +void CheckIfKnownExtension(ValidationState_t& _, const Instruction* inst) { + const std::string extension_str = GetExtensionString(&(inst->c_inst())); Extension extension; if (!GetExtensionFromString(extension_str.c_str(), &extension)) { - _.diag(SPV_SUCCESS) << "Found unrecognized extension " << extension_str; + _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Found unrecognized extension " << extension_str; return; } } -spv_result_t InstructionPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); +} // namespace + +spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); if (opcode == SpvOpExtension) { CheckIfKnownExtension(_, inst); } else if (opcode == SpvOpCapability) { - _.RegisterCapability( - static_cast(inst->words[inst->operands[0].offset])); + _.RegisterCapability(inst->GetOperandAs(0)); } else if (opcode == SpvOpMemoryModel) { - _.set_addressing_model( - static_cast(inst->words[inst->operands[0].offset])); - _.set_memory_model( - static_cast(inst->words[inst->operands[1].offset])); + if (_.has_memory_model_specified()) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "OpMemoryModel should only be provided once."; + } + _.set_addressing_model(inst->GetOperandAs(0)); + _.set_memory_model(inst->GetOperandAs(1)); } else if (opcode == SpvOpExecutionMode) { - const uint32_t entry_point = inst->words[1]; + const uint32_t entry_point = inst->word(1); _.RegisterExecutionModeForEntryPoint(entry_point, - SpvExecutionMode(inst->words[2])); + SpvExecutionMode(inst->word(2))); } else if (opcode == SpvOpVariable) { - const auto storage_class = - static_cast(inst->words[inst->operands[2].offset]); - if (auto error = LimitCheckNumVars(_, inst->result_id, storage_class)) { + const auto storage_class = inst->GetOperandAs(2); + if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) { return error; } if (storage_class == SpvStorageClassGeneric) - return _.diag(SPV_ERROR_INVALID_BINARY) + return _.diag(SPV_ERROR_INVALID_BINARY, inst) << "OpVariable storage class cannot be Generic"; if (_.current_layout_section() == kLayoutFunctionDefinitions) { if (storage_class != SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Variables must have a function[7] storage class inside" " of a function"; } if (_.current_function().IsFirstBlock( _.current_function().current_block()->id()) == false) { - return _.diag(SPV_ERROR_INVALID_CFG) << "Variables can only be defined " - "in the first block of a " - "function"; + return _.diag(SPV_ERROR_INVALID_CFG, inst) + << "Variables can only be defined " + "in the first block of a " + "function"; } } else { if (storage_class == SpvStorageClassFunction) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Variables can not have a function[7] storage class " "outside of a function"; } @@ -528,11 +564,12 @@ spv_result_t InstructionPass(ValidationState_t& _, // SPIR-V Spec 2.16.3: Validation Rules for Kernel Capabilities: The // Signedness in OpTypeInt must always be 0. - if (SpvOpTypeInt == inst->opcode && _.HasCapability(SpvCapabilityKernel) && - inst->words[inst->operands[2].offset] != 0u) { - return _.diag(SPV_ERROR_INVALID_BINARY) << "The Signedness in OpTypeInt " - "must always be 0 when Kernel " - "capability is used."; + if (SpvOpTypeInt == inst->opcode() && _.HasCapability(SpvCapabilityKernel) && + inst->GetOperandAs(2) != 0u) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "The Signedness in OpTypeInt " + "must always be 0 when Kernel " + "capability is used."; } // In order to validate decoration rules, we need to know all the decorations @@ -540,6 +577,8 @@ spv_result_t InstructionPass(ValidationState_t& _, RegisterDecorations(_, inst); if (auto error = ExtensionCheck(_, inst)) return error; + if (auto error = ReservedCheck(_, inst)) return error; + if (auto error = EnvironmentCheck(_, inst)) return error; if (auto error = CapabilityCheck(_, inst)) return error; if (auto error = LimitCheckIdBound(_, inst)) return error; if (auto error = LimitCheckStruct(_, inst)) return error; @@ -549,4 +588,6 @@ spv_result_t InstructionPass(ValidationState_t& _, // All instruction checks have passed. return SPV_SUCCESS; } -} // namespace libspirv + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_interfaces.cpp b/3rdparty/spirv-tools/source/val/validate_interfaces.cpp new file mode 100644 index 000000000..fffc6da1a --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_interfaces.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include + +#include "source/diagnostic.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Returns true if \c inst is an input or output variable. +bool is_interface_variable(const Instruction* inst) { + return inst->opcode() == SpvOpVariable && + (inst->word(3u) == SpvStorageClassInput || + inst->word(3u) == SpvStorageClassOutput); +} + +// Checks that \c var is listed as an interface in all the entry points that use +// it. +spv_result_t check_interface_variable(ValidationState_t& _, + const Instruction* var) { + std::vector functions; + std::vector uses; + for (auto use : var->uses()) { + uses.push_back(use.first); + } + for (uint32_t i = 0; i < uses.size(); ++i) { + const auto user = uses[i]; + if (const Function* func = user->function()) { + functions.push_back(func); + } else { + // In the rare case that the variable is used by another instruction in + // the global scope, continue searching for an instruction used in a + // function. + for (auto use : user->uses()) { + uses.push_back(use.first); + } + } + } + + std::sort(functions.begin(), functions.end(), + [](const Function* lhs, const Function* rhs) { + return lhs->id() < rhs->id(); + }); + functions.erase(std::unique(functions.begin(), functions.end()), + functions.end()); + + std::vector entry_points; + for (const auto func : functions) { + for (auto id : _.FunctionEntryPoints(func->id())) { + entry_points.push_back(id); + } + } + + std::sort(entry_points.begin(), entry_points.end()); + entry_points.erase(std::unique(entry_points.begin(), entry_points.end()), + entry_points.end()); + + for (auto id : entry_points) { + for (const auto& desc : _.entry_point_descriptions(id)) { + bool found = false; + for (auto interface : desc.interfaces) { + if (var->id() == interface) { + found = true; + break; + } + } + if (!found) { + return _.diag(SPV_ERROR_INVALID_ID, var) + << (var->word(3u) == SpvStorageClassInput ? "Input" : "Output") + << " variable id <" << var->id() << "> is used by entry point '" + << desc.name << "' id <" << id + << ">, but is not listed as an interface"; + } + } + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ValidateInterfaces(ValidationState_t& _) { + for (auto& inst : _.ordered_instructions()) { + if (is_interface_variable(&inst)) { + if (auto error = check_interface_variable(_, &inst)) { + return error; + } + } + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_layout.cpp b/3rdparty/spirv-tools/source/val/validate_layout.cpp similarity index 72% rename from 3rdparty/spirv-tools/source/validate_layout.cpp rename to 3rdparty/spirv-tools/source/val/validate_layout.cpp index 37c49be55..53c28355f 100644 --- a/3rdparty/spirv-tools/source/validate_layout.cpp +++ b/3rdparty/spirv-tools/source/val/validate_layout.cpp @@ -14,37 +14,33 @@ // Source code for logical layout validation as described in section 2.4 -#include "validate.h" +#include "source/val/validate.h" #include -#include "diagnostic.h" -#include "opcode.h" -#include "operand.h" -#include "spirv-tools/libspirv.h" -#include "val/function.h" -#include "val/validation_state.h" - -using libspirv::FunctionDecl; -using libspirv::kLayoutFunctionDeclarations; -using libspirv::kLayoutFunctionDefinitions; -using libspirv::kLayoutMemoryModel; -using libspirv::ValidationState_t; +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { + // Module scoped instructions are processed by determining if the opcode // is part of the current layout section. If it is not then the next sections is // checked. spv_result_t ModuleScopedInstructions(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - SpvOp opcode) { + const Instruction* inst, SpvOp opcode) { while (_.IsOpcodeInCurrentLayoutSection(opcode) == false) { _.ProgressToNextLayoutSectionOrder(); switch (_.current_layout_section()) { case kLayoutMemoryModel: if (opcode != SpvOpMemoryModel) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << spvOpcodeString(opcode) << " cannot appear before the memory model instruction"; } @@ -52,7 +48,7 @@ spv_result_t ModuleScopedInstructions(ValidationState_t& _, case kLayoutFunctionDeclarations: // All module sections have been processed. Recursively call // ModuleLayoutPass to process the next section of the module - return libspirv::ModuleLayoutPass(_, inst); + return ModuleLayoutPass(_, inst); default: break; } @@ -66,20 +62,18 @@ spv_result_t ModuleScopedInstructions(ValidationState_t& _, // inside of another function. This stage ends when the first label is // encountered inside of a function. spv_result_t FunctionScopedInstructions(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - SpvOp opcode) { + const Instruction* inst, SpvOp opcode) { if (_.IsOpcodeInCurrentLayoutSection(opcode)) { switch (opcode) { case SpvOpFunction: { if (_.in_function_body()) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Cannot declare a function in a function body"; } - auto control_mask = static_cast( - inst->words[inst->operands[2].offset]); + auto control_mask = inst->GetOperandAs(2); if (auto error = - _.RegisterFunction(inst->result_id, inst->type_id, control_mask, - inst->words[inst->operands[3].offset])) + _.RegisterFunction(inst->id(), inst->type_id(), control_mask, + inst->GetOperandAs(3))) return error; if (_.current_layout_section() == kLayoutFunctionDefinitions) { if (auto error = _.current_function().RegisterSetFunctionDeclType( @@ -90,34 +84,34 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _, case SpvOpFunctionParameter: if (_.in_function_body() == false) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) << "Function parameter " - "instructions must be in " - "a function body"; + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameter instructions must be in a " + "function body"; } if (_.current_function().block_count() != 0) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Function parameters must only appear immediately after " "the function definition"; } if (auto error = _.current_function().RegisterFunctionParameter( - inst->result_id, inst->type_id)) + inst->id(), inst->type_id())) return error; break; case SpvOpFunctionEnd: if (_.in_function_body() == false) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Function end instructions must be in a function body"; } if (_.in_block()) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Function end cannot be called in blocks"; } if (_.current_function().block_count() == 0 && _.current_layout_section() == kLayoutFunctionDefinitions) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) << "Function declarations " - "must appear before " - "function definitions."; + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function declarations must appear before " + "function definitions."; } if (_.current_layout_section() == kLayoutFunctionDeclarations) { if (auto error = _.current_function().RegisterSetFunctionDeclType( @@ -135,11 +129,11 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _, // definition so set the function to a declaration and update the // module section if (_.in_function_body() == false) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Label instructions must be in a function body"; } if (_.in_block()) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "A block must end with a branch instruction."; } if (_.current_layout_section() == kLayoutFunctionDeclarations) { @@ -153,33 +147,32 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _, default: if (_.current_layout_section() == kLayoutFunctionDeclarations && _.in_function_body()) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "A function must begin with a label"; } else { if (_.in_block() == false) { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << spvOpcodeString(opcode) << " must appear in a block"; } } break; } } else { - return _.diag(SPV_ERROR_INVALID_LAYOUT) + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << spvOpcodeString(opcode) << " cannot appear in a function declaration"; } return SPV_SUCCESS; } + } // namespace -namespace libspirv { // TODO(umar): Check linkage capabilities for function declarations // TODO(umar): Better error messages // NOTE: This function does not handle CFG related validation // Performs logical layout validation. See Section 2.4 -spv_result_t ModuleLayoutPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); +spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); switch (_.current_layout_section()) { case kLayoutCapabilities: @@ -204,4 +197,6 @@ spv_result_t ModuleLayoutPass(ValidationState_t& _, } return SPV_SUCCESS; } -} // namespace libspirv + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_literals.cpp b/3rdparty/spirv-tools/source/val/validate_literals.cpp similarity index 72% rename from 3rdparty/spirv-tools/source/validate_literals.cpp rename to 3rdparty/spirv-tools/source/val/validate_literals.cpp index 5ced89ef5..53aae0767 100644 --- a/3rdparty/spirv-tools/source/validate_literals.cpp +++ b/3rdparty/spirv-tools/source/val/validate_literals.cpp @@ -14,22 +14,22 @@ // Validates literal numbers. -#include "validate.h" +#include "source/val/validate.h" #include -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { // Returns true if the operand holds a literal number -bool IsLiteralNumber(const spv_parsed_operand_t* operand) { - switch (operand->number_kind) { +bool IsLiteralNumber(const spv_parsed_operand_t& operand) { + switch (operand.number_kind) { case SPV_NUMBER_SIGNED_INT: case SPV_NUMBER_UNSIGNED_INT: case SPV_NUMBER_FLOATING: @@ -64,31 +64,30 @@ bool VerifyUpperBits(uint32_t word, uint32_t width, bool signed_int) { } // namespace // Validates that literal numbers are represented according to the spec -spv_result_t LiteralsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { +spv_result_t LiteralsPass(ValidationState_t& _, const Instruction* inst) { // For every operand that is a literal number - for (uint16_t i = 0; i < inst->num_operands; i++) { - const spv_parsed_operand_t* operand = inst->operands + i; + for (size_t i = 0; i < inst->operands().size(); i++) { + const spv_parsed_operand_t& operand = inst->operand(i); if (!IsLiteralNumber(operand)) continue; // The upper bits are always in the last word (little-endian) - int last_index = operand->offset + operand->num_words - 1; - const uint32_t upper_word = inst->words[last_index]; + int last_index = operand.offset + operand.num_words - 1; + const uint32_t upper_word = inst->word(last_index); // TODO(jcaraban): is the |word size| defined in some header? const uint32_t word_size = 32; - uint32_t bit_width = operand->number_bit_width; + uint32_t bit_width = operand.number_bit_width; // Bit widths that are a multiple of the word size have no upper bits const auto remaining_value_bits = bit_width % word_size; if (remaining_value_bits == 0) continue; - const bool signedness = operand->number_kind == SPV_NUMBER_SIGNED_INT; + const bool signedness = operand.number_kind == SPV_NUMBER_SIGNED_INT; if (!VerifyUpperBits(upper_word, remaining_value_bits, signedness)) { - return _.diag(SPV_ERROR_INVALID_VALUE) + return _.diag(SPV_ERROR_INVALID_VALUE, inst) << "The high-order bits of a literal number in instruction " - << inst->result_id << " must be 0 for a floating-point type, " + << inst->id() << " must be 0 for a floating-point type, " << "or 0 for an integer type with Signedness of 0, " << "or sign extended when Signedness is 1"; } @@ -96,4 +95,5 @@ spv_result_t LiteralsPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_logicals.cpp b/3rdparty/spirv-tools/source/val/validate_logicals.cpp similarity index 70% rename from 3rdparty/spirv-tools/source/validate_logicals.cpp rename to 3rdparty/spirv-tools/source/val/validate_logicals.cpp index 5a6c034f0..9c637c423 100644 --- a/3rdparty/spirv-tools/source/validate_logicals.cpp +++ b/3rdparty/spirv-tools/source/val/validate_logicals.cpp @@ -14,53 +14,32 @@ // Validates correctness of logical SPIR-V instructions. -#include "validate.h" +#include "source/val/validate.h" -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" -namespace libspirv { - -namespace { - -// Returns operand word for given instruction and operand index. -// The operand is expected to only have one word. -inline uint32_t GetOperandWord(const spv_parsed_instruction_t* inst, - size_t operand_index) { - assert(operand_index < inst->num_operands); - const spv_parsed_operand_t& operand = inst->operands[operand_index]; - assert(operand.num_words == 1); - return inst->words[operand.offset]; -} - -// Returns the type id of instruction operand at |operand_index|. -// The operand is expected to be an id. -inline uint32_t GetOperandTypeId(ValidationState_t& _, - const spv_parsed_instruction_t* inst, - size_t operand_index) { - return _.GetTypeId(GetOperandWord(inst, operand_index)); -} -} // namespace +namespace spvtools { +namespace val { // Validates correctness of logical instructions. -spv_result_t LogicalsPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; +spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); switch (opcode) { case SpvOpAny: case SpvOpAll: { if (!_.IsBoolScalarType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar type as Result Type: " << spvOpcodeString(opcode); - const uint32_t vector_type = GetOperandTypeId(_, inst, 2); + const uint32_t vector_type = _.GetOperandTypeId(inst, 2); if (!vector_type || !_.IsBoolVectorType(vector_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operand to be vector bool: " << spvOpcodeString(opcode); @@ -73,19 +52,19 @@ spv_result_t LogicalsPass(ValidationState_t& _, case SpvOpIsNormal: case SpvOpSignBitSet: { if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t operand_type = GetOperandTypeId(_, inst, 2); + const uint32_t operand_type = _.GetOperandTypeId(inst, 2); if (!operand_type || (!_.IsFloatScalarType(operand_type) && !_.IsFloatVectorType(operand_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operand to be scalar or vector float: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(operand_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected vector sizes of Result Type and the operand to be " "equal: " << spvOpcodeString(opcode); @@ -109,25 +88,25 @@ spv_result_t LogicalsPass(ValidationState_t& _, case SpvOpOrdered: case SpvOpUnordered: { if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t left_operand_type = GetOperandTypeId(_, inst, 2); + const uint32_t left_operand_type = _.GetOperandTypeId(inst, 2); if (!left_operand_type || (!_.IsFloatScalarType(left_operand_type) && !_.IsFloatVectorType(left_operand_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operands to be scalar or vector float: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(left_operand_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected vector sizes of Result Type and the operands to be " "equal: " << spvOpcodeString(opcode); - if (left_operand_type != GetOperandTypeId(_, inst, 3)) - return _.diag(SPV_ERROR_INVALID_DATA) + if (left_operand_type != _.GetOperandTypeId(inst, 3)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected left and right operands to have the same type: " << spvOpcodeString(opcode); @@ -139,13 +118,13 @@ spv_result_t LogicalsPass(ValidationState_t& _, case SpvOpLogicalOr: case SpvOpLogicalAnd: { if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar or vector type as Result Type: " << spvOpcodeString(opcode); - if (result_type != GetOperandTypeId(_, inst, 2) || - result_type != GetOperandTypeId(_, inst, 3)) - return _.diag(SPV_ERROR_INVALID_DATA) + if (result_type != _.GetOperandTypeId(inst, 2) || + result_type != _.GetOperandTypeId(inst, 3)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected both operands to be of Result Type: " << spvOpcodeString(opcode); @@ -154,12 +133,12 @@ spv_result_t LogicalsPass(ValidationState_t& _, case SpvOpLogicalNot: { if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar or vector type as Result Type: " << spvOpcodeString(opcode); - if (result_type != GetOperandTypeId(_, inst, 2)) - return _.diag(SPV_ERROR_INVALID_DATA) + if (result_type != _.GetOperandTypeId(inst, 2)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operand to be of Result Type: " << spvOpcodeString(opcode); @@ -177,7 +156,7 @@ spv_result_t LogicalsPass(ValidationState_t& _, case SpvOpTypePointer: { if (!_.features().variable_pointers && !_.features().variable_pointers_storage_buffer) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Using pointers with OpSelect requires capability " << "VariablePointers or VariablePointersStorageBuffer"; break; @@ -195,30 +174,30 @@ spv_result_t LogicalsPass(ValidationState_t& _, } default: { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected scalar or vector type as Result Type: " << spvOpcodeString(opcode); } } } - const uint32_t condition_type = GetOperandTypeId(_, inst, 2); - const uint32_t left_type = GetOperandTypeId(_, inst, 3); - const uint32_t right_type = GetOperandTypeId(_, inst, 4); + const uint32_t condition_type = _.GetOperandTypeId(inst, 2); + const uint32_t left_type = _.GetOperandTypeId(inst, 3); + const uint32_t right_type = _.GetOperandTypeId(inst, 4); if (!condition_type || (!_.IsBoolScalarType(condition_type) && !_.IsBoolVectorType(condition_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar or vector type as condition: " << spvOpcodeString(opcode); if (_.GetDimension(condition_type) != dimension) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected vector sizes of Result Type and the condition to be" << " equal: " << spvOpcodeString(opcode); if (result_type != left_type || result_type != right_type) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected both objects to be of Result Type: " << spvOpcodeString(opcode); @@ -236,37 +215,37 @@ spv_result_t LogicalsPass(ValidationState_t& _, case SpvOpSLessThan: case SpvOpSLessThanEqual: { if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected bool scalar or vector type as Result Type: " << spvOpcodeString(opcode); - const uint32_t left_type = GetOperandTypeId(_, inst, 2); - const uint32_t right_type = GetOperandTypeId(_, inst, 3); + const uint32_t left_type = _.GetOperandTypeId(inst, 2); + const uint32_t right_type = _.GetOperandTypeId(inst, 3); if (!left_type || (!_.IsIntScalarType(left_type) && !_.IsIntVectorType(left_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operands to be scalar or vector int: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(left_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected vector sizes of Result Type and the operands to be" << " equal: " << spvOpcodeString(opcode); if (!right_type || (!_.IsIntScalarType(right_type) && !_.IsIntVectorType(right_type))) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected operands to be scalar or vector int: " << spvOpcodeString(opcode); if (_.GetDimension(result_type) != _.GetDimension(right_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected vector sizes of Result Type and the operands to be" << " equal: " << spvOpcodeString(opcode); if (_.GetBitWidth(left_type) != _.GetBitWidth(right_type)) - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected both operands to have the same component bit " "width: " << spvOpcodeString(opcode); @@ -281,4 +260,5 @@ spv_result_t LogicalsPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_memory.cpp b/3rdparty/spirv-tools/source/val/validate_memory.cpp new file mode 100644 index 000000000..18c3cca0f --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_memory.cpp @@ -0,0 +1,590 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include +#include + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*, + const Instruction*); +bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*, + const Instruction*); +bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*, + const Instruction*); +bool HasConflictingMemberOffsets(const std::vector&, + const std::vector&); + +// Returns true if the two instructions represent structs that, as far as the +// validator can tell, have the exact same data layout. +bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1, + const Instruction* type2) { + if (type1->opcode() != SpvOpTypeStruct) { + return false; + } + if (type2->opcode() != SpvOpTypeStruct) { + return false; + } + + if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false; + + return HaveSameLayoutDecorations(_, type1, type2); +} + +// Returns true if the operands to the OpTypeStruct instruction defining the +// types are the same or are layout compatible types. |type1| and |type2| must +// be OpTypeStruct instructions. +bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1, + const Instruction* type2) { + assert(type1->opcode() == SpvOpTypeStruct && + "type1 must be and OpTypeStruct instruction."); + assert(type2->opcode() == SpvOpTypeStruct && + "type2 must be and OpTypeStruct instruction."); + const auto& type1_operands = type1->operands(); + const auto& type2_operands = type2->operands(); + if (type1_operands.size() != type2_operands.size()) { + return false; + } + + for (size_t operand = 2; operand < type1_operands.size(); ++operand) { + if (type1->word(operand) != type2->word(operand)) { + auto def1 = _.FindDef(type1->word(operand)); + auto def2 = _.FindDef(type2->word(operand)); + if (!AreLayoutCompatibleStructs(_, def1, def2)) { + return false; + } + } + } + return true; +} + +// Returns true if all decorations that affect the data layout of the struct +// (like Offset), are the same for the two types. |type1| and |type2| must be +// OpTypeStruct instructions. +bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1, + const Instruction* type2) { + assert(type1->opcode() == SpvOpTypeStruct && + "type1 must be and OpTypeStruct instruction."); + assert(type2->opcode() == SpvOpTypeStruct && + "type2 must be and OpTypeStruct instruction."); + const std::vector& type1_decorations = + _.id_decorations(type1->id()); + const std::vector& type2_decorations = + _.id_decorations(type2->id()); + + // TODO: Will have to add other check for arrays an matricies if we want to + // handle them. + if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) { + return false; + } + + return true; +} + +bool HasConflictingMemberOffsets( + const std::vector& type1_decorations, + const std::vector& type2_decorations) { + { + // We are interested in conflicting decoration. If a decoration is in one + // list but not the other, then we will assume the code is correct. We are + // looking for things we know to be wrong. + // + // We do not have to traverse type2_decoration because, after traversing + // type1_decorations, anything new will not be found in + // type1_decoration. Therefore, it cannot lead to a conflict. + for (const Decoration& decoration : type1_decorations) { + switch (decoration.dec_type()) { + case SpvDecorationOffset: { + // Since these affect the layout of the struct, they must be present + // in both structs. + auto compare = [&decoration](const Decoration& rhs) { + if (rhs.dec_type() != SpvDecorationOffset) return false; + return decoration.struct_member_index() == + rhs.struct_member_index(); + }; + auto i = std::find_if(type2_decorations.begin(), + type2_decorations.end(), compare); + if (i != type2_decorations.end() && + decoration.params().front() != i->params().front()) { + return true; + } + } break; + default: + // This decoration does not affect the layout of the structure, so + // just moving on. + break; + } + } + } + return false; +} + +spv_result_t ValidateVariable(ValidationState_t& _, const Instruction& inst) { + auto result_type = _.FindDef(inst.type_id()); + if (!result_type || result_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpVariable Result Type '" << _.getIdName(inst.type_id()) + << "' is not a pointer type."; + } + + const auto initializer_index = 3; + if (initializer_index < inst.operands().size()) { + const auto initializer_id = inst.GetOperandAs(initializer_index); + const auto initializer = _.FindDef(initializer_id); + const auto storage_class_index = 2; + const auto is_module_scope_var = + initializer && (initializer->opcode() == SpvOpVariable) && + (initializer->GetOperandAs(storage_class_index) != + SpvStorageClassFunction); + const auto is_constant = + initializer && spvOpcodeIsConstant(initializer->opcode()); + if (!initializer || !(is_constant || is_module_scope_var)) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpVariable Initializer '" << _.getIdName(initializer_id) + << "' is not a constant or module-scope variable."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateLoad(ValidationState_t& _, const Instruction& inst) { + const auto result_type = _.FindDef(inst.type_id()); + if (!result_type) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpLoad Result Type '" << _.getIdName(inst.type_id()) + << "' is not defined."; + } + + const bool uses_variable_pointers = + _.features().variable_pointers || + _.features().variable_pointers_storage_buffer; + const auto pointer_index = 2; + const auto pointer_id = inst.GetOperandAs(pointer_index); + const auto pointer = _.FindDef(pointer_id); + if (!pointer || + ((_.addressing_model() == SpvAddressingModelLogical) && + ((!uses_variable_pointers && + !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || + (uses_variable_pointers && + !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpLoad Pointer '" << _.getIdName(pointer_id) + << "' is not a logical pointer."; + } + + const auto pointer_type = _.FindDef(pointer->type_id()); + if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpLoad type for pointer '" << _.getIdName(pointer_id) + << "' is not a pointer type."; + } + + const auto pointee_type = _.FindDef(pointer_type->GetOperandAs(2)); + if (!pointee_type || result_type->id() != pointee_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpLoad Result Type '" << _.getIdName(inst.type_id()) + << "' does not match Pointer '" << _.getIdName(pointer->id()) + << "'s type."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateStore(ValidationState_t& _, const Instruction& inst) { + const bool uses_variable_pointer = + _.features().variable_pointers || + _.features().variable_pointers_storage_buffer; + const auto pointer_index = 0; + const auto pointer_id = inst.GetOperandAs(pointer_index); + const auto pointer = _.FindDef(pointer_id); + if (!pointer || + (_.addressing_model() == SpvAddressingModelLogical && + ((!uses_variable_pointer && + !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || + (uses_variable_pointer && + !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "' is not a logical pointer."; + } + const auto pointer_type = _.FindDef(pointer->type_id()); + if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore type for pointer '" << _.getIdName(pointer_id) + << "' is not a pointer type."; + } + const auto type_id = pointer_type->GetOperandAs(2); + const auto type = _.FindDef(type_id); + if (!type || SpvOpTypeVoid == type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "'s type is void."; + } + + // validate storage class + { + uint32_t data_type; + uint32_t storage_class; + if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "' is not pointer type"; + } + + if (storage_class == SpvStorageClassUniformConstant || + storage_class == SpvStorageClassInput || + storage_class == SpvStorageClassPushConstant) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "' storage class is read-only"; + } + } + + const auto object_index = 1; + const auto object_id = inst.GetOperandAs(object_index); + const auto object = _.FindDef(object_id); + if (!object || !object->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Object '" << _.getIdName(object_id) + << "' is not an object."; + } + const auto object_type = _.FindDef(object->type_id()); + if (!object_type || SpvOpTypeVoid == object_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Object '" << _.getIdName(object_id) + << "'s type is void."; + } + + if (type->id() != object_type->id()) { + if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct || + object_type->opcode() != SpvOpTypeStruct) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "'s type does not match Object '" + << _.getIdName(object->id()) << "'s type."; + } + + // TODO: Check for layout compatible matricies and arrays as well. + if (!AreLayoutCompatibleStructs(_, type, object_type)) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "'s layout does not match Object '" + << _.getIdName(object->id()) << "'s layout."; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction& inst) { + const auto target_index = 0; + const auto target_id = inst.GetOperandAs(target_index); + const auto target = _.FindDef(target_id); + if (!target) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Target operand '" << _.getIdName(target_id) + << "' is not defined."; + } + + const auto source_index = 1; + const auto source_id = inst.GetOperandAs(source_index); + const auto source = _.FindDef(source_id); + if (!source) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Source operand '" << _.getIdName(source_id) + << "' is not defined."; + } + + const auto target_pointer_type = _.FindDef(target->type_id()); + if (!target_pointer_type || + target_pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Target operand '" << _.getIdName(target_id) + << "' is not a pointer."; + } + + const auto source_pointer_type = _.FindDef(source->type_id()); + if (!source_pointer_type || + source_pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Source operand '" << _.getIdName(source_id) + << "' is not a pointer."; + } + + if (inst.opcode() == SpvOpCopyMemory) { + const auto target_type = + _.FindDef(target_pointer_type->GetOperandAs(2)); + if (!target_type || target_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Target operand '" << _.getIdName(target_id) + << "' cannot be a void pointer."; + } + + const auto source_type = + _.FindDef(source_pointer_type->GetOperandAs(2)); + if (!source_type || source_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Source operand '" << _.getIdName(source_id) + << "' cannot be a void pointer."; + } + + if (target_type->id() != source_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Target '" << _.getIdName(source_id) + << "'s type does not match Source '" + << _.getIdName(source_type->id()) << "'s type."; + } + } else { + const auto size_id = inst.GetOperandAs(2); + const auto size = _.FindDef(size_id); + if (!size) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Size operand '" << _.getIdName(size_id) + << "' is not defined."; + } + + const auto size_type = _.FindDef(size->type_id()); + if (!_.IsIntScalarType(size_type->id())) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Size operand '" << _.getIdName(size_id) + << "' must be a scalar integer type."; + } + + bool is_zero = true; + switch (size->opcode()) { + case SpvOpConstantNull: + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Size operand '" << _.getIdName(size_id) + << "' cannot be a constant zero."; + case SpvOpConstant: + if (size_type->word(3) == 1 && + size->word(size->words().size() - 1) & 0x80000000) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Size operand '" << _.getIdName(size_id) + << "' cannot have the sign bit set to 1."; + } + for (size_t i = 3; is_zero && i < size->words().size(); ++i) { + is_zero &= (size->word(i) == 0); + } + if (is_zero) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Size operand '" << _.getIdName(size_id) + << "' cannot be a constant zero."; + } + break; + default: + // Cannot infer any other opcodes. + break; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateAccessChain(ValidationState_t& _, + const Instruction& inst) { + std::string instr_name = + "Op" + std::string(spvOpcodeString(static_cast(inst.opcode()))); + + // The result type must be OpTypePointer. + auto result_type = _.FindDef(inst.type_id()); + if (SpvOpTypePointer != result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "The Result Type of " << instr_name << " '" + << _.getIdName(inst.id()) << "' must be OpTypePointer. Found Op" + << spvOpcodeString(static_cast(result_type->opcode())) << "."; + } + + // Result type is a pointer. Find out what it's pointing to. + // This will be used to make sure the indexing results in the same type. + // OpTypePointer word 3 is the type being pointed to. + const auto result_type_pointee = _.FindDef(result_type->word(3)); + + // Base must be a pointer, pointing to the base of a composite object. + const auto base_index = 2; + const auto base_id = inst.GetOperandAs(base_index); + const auto base = _.FindDef(base_id); + const auto base_type = _.FindDef(base->type_id()); + if (!base_type || SpvOpTypePointer != base_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "The Base '" << _.getIdName(base_id) << "' in " << instr_name + << " instruction must be a pointer."; + } + + // The result pointer storage class and base pointer storage class must match. + // Word 2 of OpTypePointer is the Storage Class. + auto result_type_storage_class = result_type->word(2); + auto base_type_storage_class = base_type->word(2); + if (result_type_storage_class != base_type_storage_class) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "The result pointer storage class and base " + "pointer storage class in " + << instr_name << " do not match."; + } + + // The type pointed to by OpTypePointer (word 3) must be a composite type. + auto type_pointee = _.FindDef(base_type->word(3)); + + // Check Universal Limit (SPIR-V Spec. Section 2.17). + // The number of indexes passed to OpAccessChain may not exceed 255 + // The instruction includes 4 words + N words (for N indexes) + size_t num_indexes = inst.words().size() - 4; + if (inst.opcode() == SpvOpPtrAccessChain || + inst.opcode() == SpvOpInBoundsPtrAccessChain) { + // In pointer access chains, the element operand is required, but not + // counted as an index. + --num_indexes; + } + const size_t num_indexes_limit = + _.options()->universal_limits_.max_access_chain_indexes; + if (num_indexes > num_indexes_limit) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "The number of indexes in " << instr_name << " may not exceed " + << num_indexes_limit << ". Found " << num_indexes << " indexes."; + } + // Indexes walk the type hierarchy to the desired depth, potentially down to + // scalar granularity. The first index in Indexes will select the top-level + // member/element/component/element of the base composite. All composite + // constituents use zero-based numbering, as described by their OpType... + // instruction. The second index will apply similarly to that result, and so + // on. Once any non-composite type is reached, there must be no remaining + // (unused) indexes. + auto starting_index = 4; + if (inst.opcode() == SpvOpPtrAccessChain || + inst.opcode() == SpvOpInBoundsPtrAccessChain) { + ++starting_index; + } + for (size_t i = starting_index; i < inst.words().size(); ++i) { + const uint32_t cur_word = inst.words()[i]; + // Earlier ID checks ensure that cur_word definition exists. + auto cur_word_instr = _.FindDef(cur_word); + // The index must be a scalar integer type (See OpAccessChain in the Spec.) + auto index_type = _.FindDef(cur_word_instr->type_id()); + if (!index_type || SpvOpTypeInt != index_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << "Indexes passed to " << instr_name + << " must be of type integer."; + } + switch (type_pointee->opcode()) { + case SpvOpTypeMatrix: + case SpvOpTypeVector: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: { + // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray, + // word 2 is the Element Type. + type_pointee = _.FindDef(type_pointee->word(2)); + break; + } + case SpvOpTypeStruct: { + // In case of structures, there is an additional constraint on the + // index: the index must be an OpConstant. + if (SpvOpConstant != cur_word_instr->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << "The passed to " << instr_name + << " to index into a " + "structure must be an OpConstant."; + } + // Get the index value from the OpConstant (word 3 of OpConstant). + // OpConstant could be a signed integer. But it's okay to treat it as + // unsigned because a negative constant int would never be seen as + // correct as a struct offset, since structs can't have more than 2 + // billion members. + const uint32_t cur_index = cur_word_instr->word(3); + // The index points to the struct member we want, therefore, the index + // should be less than the number of struct members. + const uint32_t num_struct_members = + static_cast(type_pointee->words().size() - 2); + if (cur_index >= num_struct_members) { + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << "Index is out of bounds: " << instr_name + << " can not find index " << cur_index + << " into the structure '" + << _.getIdName(type_pointee->id()) << "'. This structure has " + << num_struct_members << " members. Largest valid index is " + << num_struct_members - 1 << "."; + } + // Struct members IDs start at word 2 of OpTypeStruct. + auto structMemberId = type_pointee->word(cur_index + 2); + type_pointee = _.FindDef(structMemberId); + break; + } + default: { + // Give an error. reached non-composite type while indexes still remain. + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << instr_name + << " reached non-composite type while indexes " + "still remain to be traversed."; + } + } + } + // At this point, we have fully walked down from the base using the indeces. + // The type being pointed to should be the same as the result type. + if (type_pointee->id() != result_type_pointee->id()) { + return _.diag(SPV_ERROR_INVALID_ID, &inst) + << instr_name << " result type (Op" + << spvOpcodeString(static_cast(result_type_pointee->opcode())) + << ") does not match the type that results from indexing into the " + "base " + " (Op" + << spvOpcodeString(static_cast(type_pointee->opcode())) + << ")."; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ValidateMemoryInstructions(ValidationState_t& _, + const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpVariable: + if (auto error = ValidateVariable(_, *inst)) return error; + break; + case SpvOpLoad: + if (auto error = ValidateLoad(_, *inst)) return error; + break; + case SpvOpStore: + if (auto error = ValidateStore(_, *inst)) return error; + break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: + if (auto error = ValidateCopyMemory(_, *inst)) return error; + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + if (auto error = ValidateAccessChain(_, *inst)) return error; + break; + case SpvOpImageTexelPointer: + case SpvOpArrayLength: + case SpvOpGenericPtrMemSemantics: + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp b/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp new file mode 100644 index 000000000..60e7c0a1d --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#include "source/val/validate.h" + +#include + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) { + const auto entry_point_id = inst->GetOperandAs(1); + auto entry_point = _.FindDef(entry_point_id); + if (!entry_point || SpvOpFunction != entry_point->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_point_id) + << "' is not a function."; + } + // don't check kernel function signatures + const SpvExecutionModel execution_model = + inst->GetOperandAs(0); + if (execution_model != SpvExecutionModelKernel) { + // TODO: Check the entry point signature is void main(void), may be subject + // to change + const auto entry_point_type_id = entry_point->GetOperandAs(3); + const auto entry_point_type = _.FindDef(entry_point_type_id); + if (!entry_point_type || 3 != entry_point_type->words().size()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_point_id) + << "'s function parameter count is not zero."; + } + } + + auto return_type = _.FindDef(entry_point->type_id()); + if (!return_type || SpvOpTypeVoid != return_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_point_id) + << "'s function return type is not void."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateExecutionMode(ValidationState_t& _, + const Instruction* inst) { + const auto entry_point_id = inst->GetOperandAs(0); + const auto found = std::find(_.entry_points().cbegin(), + _.entry_points().cend(), entry_point_id); + if (found == _.entry_points().cend()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpExecutionMode Entry Point '" + << _.getIdName(entry_point_id) + << "' is not the Entry Point " + "operand of an OpEntryPoint."; + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpEntryPoint: + if (auto error = ValidateEntryPoint(_, inst)) return error; + break; + case SpvOpExecutionMode: + if (auto error = ValidateExecutionMode(_, inst)) return error; + break; + default: + break; + } + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_non_uniform.cpp b/3rdparty/spirv-tools/source/val/validate_non_uniform.cpp new file mode 100644 index 000000000..89e82c616 --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_non_uniform.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of barrier SPIR-V instructions. + +#include "source/val/validate.h" + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateExecutionScope(ValidationState_t& _, + const Instruction* inst, uint32_t scope) { + SpvOp opcode = inst->opcode(); + bool is_int32 = false, is_const_int32 = false; + uint32_t value = 0; + std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(scope); + + if (!is_int32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Execution Scope to be a 32-bit int"; + } + + if (!is_const_int32) { + return SPV_SUCCESS; + } + + if (spvIsVulkanEnv(_.context()->target_env) && + _.context()->target_env != SPV_ENV_VULKAN_1_0 && + value != SpvScopeSubgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan environment Execution scope is limited to " + "Subgroup"; + } + + if (value != SpvScopeSubgroup && value != SpvScopeWorkgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Execution scope is limited to Subgroup or Workgroup"; + } + + return SPV_SUCCESS; +} + +} // namespace + +// Validates correctness of non-uniform group instructions. +spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + + if (spvOpcodeIsNonUniformGroupOperation(opcode)) { + const uint32_t execution_scope = inst->word(3); + if (auto error = ValidateExecutionScope(_, inst, execution_scope)) { + return error; + } + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_primitives.cpp b/3rdparty/spirv-tools/source/val/validate_primitives.cpp similarity index 66% rename from 3rdparty/spirv-tools/source/validate_primitives.cpp rename to 3rdparty/spirv-tools/source/val/validate_primitives.cpp index de5bc2c30..7d11f2e7a 100644 --- a/3rdparty/spirv-tools/source/validate_primitives.cpp +++ b/3rdparty/spirv-tools/source/val/validate_primitives.cpp @@ -14,31 +14,32 @@ // Validates correctness of primitive SPIR-V instructions. -#include "validate.h" +#include "source/val/validate.h" #include -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" -namespace libspirv { +namespace spvtools { +namespace val { // Validates correctness of primitive instructions. -spv_result_t PrimitivesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); +spv_result_t PrimitivesPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); switch (opcode) { case SpvOpEmitVertex: case SpvOpEndPrimitive: case SpvOpEmitStreamVertex: case SpvOpEndStreamPrimitive: - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelGeometry, - std::string(spvOpcodeString(opcode)) + - " instructions require Geometry execution model"); + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelGeometry, + std::string(spvOpcodeString(opcode)) + + " instructions require Geometry execution model"); break; default: break; @@ -47,17 +48,17 @@ spv_result_t PrimitivesPass(ValidationState_t& _, switch (opcode) { case SpvOpEmitStreamVertex: case SpvOpEndStreamPrimitive: { - const uint32_t stream_id = inst->words[1]; + const uint32_t stream_id = inst->word(1); const uint32_t stream_type = _.GetTypeId(stream_id); if (!_.IsIntScalarType(stream_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Stream to be int scalar"; } const SpvOp stream_opcode = _.GetIdOpcode(stream_id); if (!spvOpcodeIsConstant(stream_opcode)) { - return _.diag(SPV_ERROR_INVALID_DATA) + return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) << ": expected Stream to be constant instruction"; } @@ -70,4 +71,5 @@ spv_result_t PrimitivesPass(ValidationState_t& _, return SPV_SUCCESS; } -} // namespace libspirv +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validate_type.cpp b/3rdparty/spirv-tools/source/val/validate_type.cpp new file mode 100644 index 000000000..b6942272e --- /dev/null +++ b/3rdparty/spirv-tools/source/val/validate_type.cpp @@ -0,0 +1,316 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Ensures type declarations are unique unless allowed by the specification. + +#include "source/val/validate.h" + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// True if the integer constant is > 0. |const_words| are words of the +// constant-defining instruction (either OpConstant or +// OpSpecConstant). typeWords are the words of the constant's-type-defining +// OpTypeInt. +bool AboveZero(const std::vector& const_words, + const std::vector& type_words) { + const uint32_t width = type_words[2]; + const bool is_signed = type_words[3] > 0; + const uint32_t lo_word = const_words[3]; + if (width > 32) { + // The spec currently doesn't allow integers wider than 64 bits. + const uint32_t hi_word = const_words[4]; // Must exist, per spec. + if (is_signed && (hi_word >> 31)) return false; + return (lo_word | hi_word) > 0; + } else { + if (is_signed && (lo_word >> 31)) return false; + return lo_word > 0; + } +} + +// Validates that type declarations are unique, unless multiple declarations +// of the same data type are allowed by the specification. +// (see section 2.8 Types and Variables) +// Doesn't do anything if SPV_VAL_ignore_type_decl_unique was declared in the +// module. +spv_result_t ValidateUniqueness(ValidationState_t& _, const Instruction* inst) { + if (_.HasExtension(Extension::kSPV_VALIDATOR_ignore_type_decl_unique)) + return SPV_SUCCESS; + + const auto opcode = inst->opcode(); + if (opcode != SpvOpTypeArray && opcode != SpvOpTypeRuntimeArray && + opcode != SpvOpTypeStruct && opcode != SpvOpTypePointer && + !_.RegisterUniqueTypeDeclaration(inst)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Duplicate non-aggregate type declarations are not allowed. " + "Opcode: " + << spvOpcodeString(opcode) << " id: " << inst->id(); + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) { + const auto component_index = 1; + const auto component_id = inst->GetOperandAs(component_index); + const auto component_type = _.FindDef(component_id); + if (!component_type || !spvOpcodeIsScalarType(component_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeVector Component Type '" << _.getIdName(component_id) + << "' is not a scalar type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeMatrix(ValidationState_t& _, const Instruction* inst) { + const auto column_type_index = 1; + const auto column_type_id = inst->GetOperandAs(column_type_index); + const auto column_type = _.FindDef(column_type_id); + if (!column_type || SpvOpTypeVector != column_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeMatrix Column Type '" << _.getIdName(column_type_id) + << "' is not a vector."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) { + const auto element_type_index = 1; + const auto element_type_id = inst->GetOperandAs(element_type_index); + const auto element_type = _.FindDef(element_type_id); + if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Element Type '" << _.getIdName(element_type_id) + << "' is not a type."; + } + const auto length_index = 2; + const auto length_id = inst->GetOperandAs(length_index); + const auto length = _.FindDef(length_id); + if (!length || !spvOpcodeIsConstant(length->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' is not a scalar constant type."; + } + + // NOTE: Check the initialiser value of the constant + const auto const_inst = length->words(); + const auto const_result_type_index = 1; + const auto const_result_type = _.FindDef(const_inst[const_result_type_index]); + if (!const_result_type || SpvOpTypeInt != const_result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' is not a constant integer type."; + } + + switch (length->opcode()) { + case SpvOpSpecConstant: + case SpvOpConstant: + if (AboveZero(length->words(), const_result_type->words())) break; + // Else fall through! + case SpvOpConstantNull: { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' default value must be at least 1."; + } + case SpvOpSpecConstantOp: + // Assume it's OK, rather than try to evaluate the operation. + break; + default: + assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int"); + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _, + const Instruction* inst) { + const auto element_type_index = 1; + const auto element_id = inst->GetOperandAs(element_type_index); + const auto element_type = _.FindDef(element_id); + if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeRuntimeArray Element Type '" + << _.getIdName(element_id) << "' is not a type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) { + const uint32_t struct_id = inst->GetOperandAs(0); + for (size_t member_type_index = 1; + member_type_index < inst->operands().size(); ++member_type_index) { + auto member_type_id = inst->GetOperandAs(member_type_index); + auto member_type = _.FindDef(member_type_id); + if (!member_type || !spvOpcodeGeneratesType(member_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeStruct Member Type '" << _.getIdName(member_type_id) + << "' is not a type."; + } + if (member_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Structures cannot contain a void type."; + } + if (SpvOpTypeStruct == member_type->opcode() && + _.IsStructTypeWithBuiltInMember(member_type_id)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Structure " << _.getIdName(member_type_id) + << " contains members with BuiltIn decoration. Therefore this " + "structure may not be contained as a member of another " + "structure " + "type. Structure " + << _.getIdName(struct_id) << " contains structure " + << _.getIdName(member_type_id) << "."; + } + if (_.IsForwardPointer(member_type_id)) { + if (member_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Found a forward reference to a non-pointer " + "type in OpTypeStruct instruction."; + } + // If we're dealing with a forward pointer: + // Find out the type that the pointer is pointing to (must be struct) + // word 3 is the of the type being pointed to. + auto type_pointing_to = _.FindDef(member_type->words()[3]); + if (type_pointing_to && type_pointing_to->opcode() != SpvOpTypeStruct) { + // Forward declared operands of a struct may only point to a struct. + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "A forward reference operand in an OpTypeStruct must be an " + "OpTypePointer that points to an OpTypeStruct. " + "Found OpTypePointer that points to Op" + << spvOpcodeString( + static_cast(type_pointing_to->opcode())) + << "."; + } + } + } + std::unordered_set built_in_members; + for (auto decoration : _.id_decorations(struct_id)) { + if (decoration.dec_type() == SpvDecorationBuiltIn && + decoration.struct_member_index() != Decoration::kInvalidMember) { + built_in_members.insert(decoration.struct_member_index()); + } + } + int num_struct_members = static_cast(inst->operands().size() - 1); + int num_builtin_members = static_cast(built_in_members.size()); + if (num_builtin_members > 0 && num_builtin_members != num_struct_members) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "When BuiltIn decoration is applied to a structure-type member, " + "all members of that structure type must also be decorated with " + "BuiltIn (No allowed mixing of built-in variables and " + "non-built-in variables within a single structure). Structure id " + << struct_id << " does not meet this requirement."; + } + if (num_builtin_members > 0) { + _.RegisterStructTypeWithBuiltInMember(struct_id); + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypePointer(ValidationState_t& _, + const Instruction* inst) { + const auto type_id = inst->GetOperandAs(2); + const auto type = _.FindDef(type_id); + if (!type || !spvOpcodeGeneratesType(type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypePointer Type '" << _.getIdName(type_id) + << "' is not a type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeFunction(ValidationState_t& _, + const Instruction* inst) { + const auto return_type_id = inst->GetOperandAs(1); + const auto return_type = _.FindDef(return_type_id); + if (!return_type || !spvOpcodeGeneratesType(return_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction Return Type '" << _.getIdName(return_type_id) + << "' is not a type."; + } + size_t num_args = 0; + for (size_t param_type_index = 2; param_type_index < inst->operands().size(); + ++param_type_index, ++num_args) { + const auto param_id = inst->GetOperandAs(param_type_index); + const auto param_type = _.FindDef(param_id); + if (!param_type || !spvOpcodeGeneratesType(param_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction Parameter Type '" << _.getIdName(param_id) + << "' is not a type."; + } + } + const uint32_t num_function_args_limit = + _.options()->universal_limits_.max_function_args; + if (num_args > num_function_args_limit) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction may not take more than " + << num_function_args_limit << " arguments. OpTypeFunction '" + << _.getIdName(inst->GetOperandAs(0)) << "' has " + << num_args << " arguments."; + } + + // The only valid uses of OpTypeFunction are in an OpFunction instruction. + for (auto& pair : inst->uses()) { + const auto* use = pair.first; + if (use->opcode() != SpvOpFunction) { + return _.diag(SPV_ERROR_INVALID_ID, use) + << "Invalid use of function type result id " + << _.getIdName(inst->id()) << "."; + } + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) { + if (!spvOpcodeGeneratesType(inst->opcode())) return SPV_SUCCESS; + + if (auto error = ValidateUniqueness(_, inst)) return error; + + switch (inst->opcode()) { + case SpvOpTypeVector: + if (auto error = ValidateTypeVector(_, inst)) return error; + break; + case SpvOpTypeMatrix: + if (auto error = ValidateTypeMatrix(_, inst)) return error; + break; + case SpvOpTypeArray: + if (auto error = ValidateTypeArray(_, inst)) return error; + break; + case SpvOpTypeRuntimeArray: + if (auto error = ValidateTypeRuntimeArray(_, inst)) return error; + break; + case SpvOpTypeStruct: + if (auto error = ValidateTypeStruct(_, inst)) return error; + break; + case SpvOpTypePointer: + if (auto error = ValidateTypePointer(_, inst)) return error; + break; + case SpvOpTypeFunction: + if (auto error = ValidateTypeFunction(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validation_state.cpp b/3rdparty/spirv-tools/source/val/validation_state.cpp index 50ebb6220..dec25b1c3 100644 --- a/3rdparty/spirv-tools/source/val/validation_state.cpp +++ b/3rdparty/spirv-tools/source/val/validation_state.cpp @@ -12,25 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "val/validation_state.h" +#include "source/val/validation_state.h" #include +#include +#include -#include "opcode.h" -#include "val/basic_block.h" -#include "val/construct.h" -#include "val/function.h" - -using std::deque; -using std::make_pair; -using std::pair; -using std::string; -using std::unordered_map; -using std::vector; - -namespace libspirv { +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "spirv-tools/libspirv.h" +namespace spvtools { +namespace val { namespace { + bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { // See Section 2.4 bool out = false; @@ -41,7 +39,9 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { case kLayoutExtInstImport: out = op == SpvOpExtInstImport; break; case kLayoutMemoryModel: out = op == SpvOpMemoryModel; break; case kLayoutEntryPoint: out = op == SpvOpEntryPoint; break; - case kLayoutExecutionMode: out = op == SpvOpExecutionMode; break; + case kLayoutExecutionMode: + out = op == SpvOpExecutionMode || op == SpvOpExecutionModeId; + break; case kLayoutDebug1: switch (op) { case SpvOpSourceContinued: @@ -111,6 +111,7 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { case SpvOpMemoryModel: case SpvOpEntryPoint: case SpvOpExecutionMode: + case SpvOpExecutionModeId: case SpvOpSourceContinued: case SpvOpSource: case SpvOpSourceExtension: @@ -135,13 +136,26 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { return out; } -} // anonymous namespace +// Counts the number of instructions and functions in the file. +spv_result_t CountInstructions(void* user_data, + const spv_parsed_instruction_t* inst) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + if (inst->opcode == SpvOpFunction) _.increment_total_functions(); + _.increment_total_instructions(); + + return SPV_SUCCESS; +} + +} // namespace ValidationState_t::ValidationState_t(const spv_const_context ctx, - const spv_const_validator_options opt) + const spv_const_validator_options opt, + const uint32_t* words, + const size_t num_words) : context_(ctx), options_(opt), - instruction_counter_(0), + words_(words), + num_words_(num_words), unresolved_forward_ids_{}, operand_names_{}, current_layout_section_(kLayoutCapabilities), @@ -154,10 +168,42 @@ ValidationState_t::ValidationState_t(const spv_const_context ctx, local_vars_(), struct_nesting_depth_(), grammar_(ctx), - addressing_model_(SpvAddressingModelLogical), - memory_model_(SpvMemoryModelSimple), + addressing_model_(SpvAddressingModelMax), + memory_model_(SpvMemoryModelMax), in_function_(false) { assert(opt && "Validator options may not be Null."); + + const auto env = context_->target_env; + + if (spvIsVulkanEnv(env)) { + // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core. + if (env != SPV_ENV_VULKAN_1_0) { + features_.env_relaxed_block_layout = true; + } + } + + switch (env) { + case SPV_ENV_WEBGPU_0: + features_.bans_op_undef = true; + break; + default: + break; + } + + // Only attempt to count if we have words, otherwise let the other validation + // fail and generate an error. + if (num_words > 0) { + // Count the number of instructions in the binary. + spvBinaryParse(ctx, this, words, num_words, + /* parsed_header = */ nullptr, CountInstructions, + /* diagnostic = */ nullptr); + preallocateStorage(); + } +} + +void ValidationState_t::preallocateStorage() { + ordered_instructions_.reserve(total_instructions_); + module_functions_.reserve(total_functions_); } spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) { @@ -179,11 +225,11 @@ bool ValidationState_t::IsForwardPointer(uint32_t id) const { return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end()); } -void ValidationState_t::AssignNameToId(uint32_t id, string name) { +void ValidationState_t::AssignNameToId(uint32_t id, std::string name) { operand_names_[id] = name; } -string ValidationState_t::getIdName(uint32_t id) const { +std::string ValidationState_t::getIdName(uint32_t id) const { std::stringstream out; out << id; if (operand_names_.find(id) != end(operand_names_)) { @@ -192,9 +238,9 @@ string ValidationState_t::getIdName(uint32_t id) const { return out.str(); } -string ValidationState_t::getIdOrName(uint32_t id) const { +std::string ValidationState_t::getIdOrName(uint32_t id) const { std::stringstream out; - if (operand_names_.find(id) != end(operand_names_)) { + if (operand_names_.find(id) != std::end(operand_names_)) { out << operand_names_.at(id); } else { out << id; @@ -206,14 +252,14 @@ size_t ValidationState_t::unresolved_forward_id_count() const { return unresolved_forward_ids_.size(); } -vector ValidationState_t::UnresolvedForwardIds() const { - vector out(begin(unresolved_forward_ids_), - end(unresolved_forward_ids_)); +std::vector ValidationState_t::UnresolvedForwardIds() const { + std::vector out(std::begin(unresolved_forward_ids_), + std::end(unresolved_forward_ids_)); return out; } bool ValidationState_t::IsDefinedId(uint32_t id) const { - return all_definitions_.find(id) != end(all_definitions_); + return all_definitions_.find(id) != std::end(all_definitions_); } const Instruction* ValidationState_t::FindDef(uint32_t id) const { @@ -228,11 +274,6 @@ Instruction* ValidationState_t::FindDef(uint32_t id) { return it->second; } -// Increments the instruction count. Used for diagnostic -int ValidationState_t::increment_instruction_count() { - return instruction_counter_++; -} - ModuleLayoutSection ValidationState_t::current_layout_section() const { return current_layout_section_; } @@ -249,13 +290,18 @@ bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) { return IsInstructionInLayoutSection(current_layout_section_, op); } -DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const { - return libspirv::DiagnosticStream( - {0, 0, static_cast(instruction_counter_)}, context_->consumer, - error_code); +DiagnosticStream ValidationState_t::diag(spv_result_t error_code, + const Instruction* inst) const { + std::string disassembly; + if (inst) disassembly = Disassemble(*inst); + + return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0}, + context_->consumer, disassembly, error_code); } -deque& ValidationState_t::functions() { return module_functions_; } +std::vector& ValidationState_t::functions() { + return module_functions_; +} Function& ValidationState_t::current_function() { assert(in_function_body()); @@ -273,6 +319,12 @@ const Function* ValidationState_t::function(uint32_t id) const { return it->second; } +Function* ValidationState_t::function(uint32_t id) { + auto it = id_to_function_.find(id); + if (it == id_to_function_.end()) return nullptr; + return it->second; +} + bool ValidationState_t::in_function_body() const { return in_function_; } bool ValidationState_t::in_block() const { @@ -298,6 +350,12 @@ void ValidationState_t::RegisterCapability(SpvCapability cap) { case SpvCapabilityKernel: features_.group_ops_reduce_and_scans = true; break; + case SpvCapabilityInt8: + case SpvCapabilityStorageBuffer8BitAccess: + case SpvCapabilityUniformAndStorageBuffer8BitAccess: + case SpvCapabilityStoragePushConstant8: + features_.declare_int8_type = true; + break; case SpvCapabilityInt16: features_.declare_int16_type = true; break; @@ -400,29 +458,54 @@ spv_result_t ValidationState_t::RegisterFunctionEnd() { return SPV_SUCCESS; } -void ValidationState_t::RegisterInstruction( - const spv_parsed_instruction_t& inst) { - if (in_function_body()) { - ordered_instructions_.emplace_back(&inst, ¤t_function(), - current_function().current_block()); - } else { - ordered_instructions_.emplace_back(&inst, nullptr, nullptr); - } - uint32_t id = ordered_instructions_.back().id(); - if (id) { - all_definitions_.insert(make_pair(id, &ordered_instructions_.back())); +Instruction* ValidationState_t::AddOrderedInstruction( + const spv_parsed_instruction_t* inst) { + ordered_instructions_.emplace_back(inst); + ordered_instructions_.back().SetLineNum(ordered_instructions_.size()); + return &ordered_instructions_.back(); +} + +// Improves diagnostic messages by collecting names of IDs +void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpName: { + const auto target = inst->GetOperandAs(0); + const auto* str = reinterpret_cast(inst->words().data() + + inst->operand(1).offset); + AssignNameToId(target, str); + break; + } + case SpvOpMemberName: { + const auto target = inst->GetOperandAs(0); + const auto* str = reinterpret_cast(inst->words().data() + + inst->operand(2).offset); + AssignNameToId(target, str); + break; + } + case SpvOpSourceContinued: + case SpvOpSource: + case SpvOpSourceExtension: + case SpvOpString: + case SpvOpLine: + case SpvOpNoLine: + default: + break; } +} + +void ValidationState_t::RegisterInstruction(Instruction* inst) { + if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst)); // If the instruction is using an OpTypeSampledImage as an operand, it should // be recorded. The validator will ensure that all usages of an // OpTypeSampledImage and its definition are in the same basic block. - for (uint16_t i = 0; i < inst.num_operands; ++i) { - const spv_parsed_operand_t& operand = inst.operands[i]; + for (uint16_t i = 0; i < inst->operands().size(); ++i) { + const spv_parsed_operand_t& operand = inst->operand(i); if (SPV_OPERAND_TYPE_ID == operand.type) { - const uint32_t operand_word = inst.words[operand.offset]; + const uint32_t operand_word = inst->word(operand.offset); Instruction* operand_inst = FindDef(operand_word); if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) { - RegisterSampledImageConsumer(operand_word, inst.result_id); + RegisterSampledImageConsumer(operand_word, inst->id()); } } } @@ -447,20 +530,20 @@ uint32_t ValidationState_t::getIdBound() const { return id_bound_; } void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; } -bool ValidationState_t::RegisterUniqueTypeDeclaration( - const spv_parsed_instruction_t& inst) { +bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) { std::vector key; - key.push_back(static_cast(inst.opcode)); - for (int index = 0; index < inst.num_operands; ++index) { - const spv_parsed_operand_t& operand = inst.operands[index]; + key.push_back(static_cast(inst->opcode())); + for (size_t index = 0; index < inst->operands().size(); ++index) { + const spv_parsed_operand_t& operand = inst->operand(index); if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue; const int words_begin = operand.offset; const int words_end = words_begin + operand.num_words; - assert(words_end <= static_cast(inst.num_words)); + assert(words_end <= static_cast(inst->words().size())); - key.insert(key.end(), inst.words + words_begin, inst.words + words_end); + key.insert(key.end(), inst->words().begin() + words_begin, + inst->words().begin() + words_end); } return unique_type_declarations_.insert(std::move(key)).second; @@ -744,12 +827,9 @@ bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type, return true; } -uint32_t ValidationState_t::GetOperandTypeId( - const spv_parsed_instruction_t* inst, size_t operand_index) const { - assert(operand_index < inst->num_operands); - const spv_parsed_operand_t& operand = inst->operands[operand_index]; - assert(operand.num_words == 1); - return GetTypeId(inst->words[operand.offset]); +uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst, + size_t operand_index) const { + return GetTypeId(inst->GetOperandAs(operand_index)); } bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const { @@ -780,7 +860,7 @@ std::tuple ValidationState_t::EvalInt32IfConst( assert(inst); const uint32_t type = inst->type_id(); - if (!IsIntScalarType(type) || GetBitWidth(type) != 32) { + if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) { return std::make_tuple(false, false, 0); } @@ -792,4 +872,52 @@ std::tuple ValidationState_t::EvalInt32IfConst( return std::make_tuple(true, true, inst->word(3)); } -} // namespace libspirv +void ValidationState_t::ComputeFunctionToEntryPointMapping() { + for (const uint32_t entry_point : entry_points()) { + std::stack call_stack; + std::set visited; + call_stack.push(entry_point); + while (!call_stack.empty()) { + const uint32_t called_func_id = call_stack.top(); + call_stack.pop(); + if (!visited.insert(called_func_id).second) continue; + + function_to_entry_points_[called_func_id].push_back(entry_point); + + const Function* called_func = function(called_func_id); + if (called_func) { + // Other checks should error out on this invalid SPIR-V. + for (const uint32_t new_call : called_func->function_call_targets()) { + call_stack.push(new_call); + } + } + } + } +} + +const std::vector& ValidationState_t::FunctionEntryPoints( + uint32_t func) const { + auto iter = function_to_entry_points_.find(func); + if (iter == function_to_entry_points_.end()) { + return empty_ids_; + } else { + return iter->second; + } +} + +std::string ValidationState_t::Disassemble(const Instruction& inst) const { + const spv_parsed_instruction_t& c_inst(inst.c_inst()); + return Disassemble(c_inst.words, c_inst.num_words); +} + +std::string ValidationState_t::Disassemble(const uint32_t* words, + uint16_t num_words) const { + uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES; + + return spvInstructionBinaryToText(context()->target_env, words, num_words, + words_, num_words_, disassembly_options); +} + +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/val/validation_state.h b/3rdparty/spirv-tools/source/val/validation_state.h index 3e382c80e..8c8b53196 100644 --- a/3rdparty/spirv-tools/source/val/validation_state.h +++ b/3rdparty/spirv-tools/source/val/validation_state.h @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_VAL_VALIDATIONSTATE_H_ -#define LIBSPIRV_VAL_VALIDATIONSTATE_H_ +#ifndef SOURCE_VAL_VALIDATION_STATE_H_ +#define SOURCE_VAL_VALIDATION_STATE_H_ -#include +#include #include #include #include @@ -23,17 +23,20 @@ #include #include -#include "assembly_grammar.h" -#include "decoration.h" -#include "diagnostic.h" -#include "enum_set.h" -#include "latest_version_spirv_header.h" +#include "source/assembly_grammar.h" +#include "source/diagnostic.h" +#include "source/disassemble.h" +#include "source/enum_set.h" +#include "source/latest_version_spirv_header.h" +#include "source/spirv_definition.h" +#include "source/spirv_validator_options.h" +#include "source/val/decoration.h" +#include "source/val/function.h" +#include "source/val/instruction.h" #include "spirv-tools/libspirv.h" -#include "spirv_definition.h" -#include "val/function.h" -#include "val/instruction.h" -namespace libspirv { +namespace spvtools { +namespace val { /// This enum represents the sections of a SPIRV module. See section 2.4 /// of the SPIRV spec for additional details of the order. The enumerant values @@ -57,7 +60,7 @@ enum ModuleLayoutSection { /// This class manages the state of the SPIR-V validation as it is being parsed. class ValidationState_t { public: - // Features that can optionally be turned on by a capability. + // Features that can optionally be turned on by a capability or environment. struct Feature { bool declare_int16_type = false; // Allow OpTypeInt with 16 bit width? bool declare_float16_type = false; // Allow OpTypeFloat with 16 bit width? @@ -73,10 +76,21 @@ class ValidationState_t { // Permit group oerations Reduce, InclusiveScan, ExclusiveScan bool group_ops_reduce_and_scans = false; + + // Disallows the use of OpUndef + bool bans_op_undef = false; + + // Allow OpTypeInt with 8 bit width? + bool declare_int8_type = false; + + // Target environment uses relaxed block layout. + // This is true for Vulkan 1.1 or later. + bool env_relaxed_block_layout = false; }; ValidationState_t(const spv_const_context context, - const spv_const_validator_options opt); + const spv_const_validator_options opt, + const uint32_t* words, const size_t num_words); /// Returns the context spv_const_context context() const { return context_; } @@ -84,6 +98,18 @@ class ValidationState_t { /// Returns the command line options spv_const_validator_options options() const { return options_; } + /// Sets the ID of the generator for this module. + void setGenerator(uint32_t gen) { generator_ = gen; } + + /// Returns the ID of the generator for this module. + uint32_t generator() const { return generator_; } + + /// Sets the SPIR-V version of this module. + void setVersion(uint32_t ver) { version_ = ver; } + + /// Gets the SPIR-V version of this module. + uint32_t version() const { return version_; } + /// Forward declares the id in the module spv_result_t ForwardDeclareId(uint32_t id); @@ -123,8 +149,16 @@ class ValidationState_t { /// Returns true if the id has been defined bool IsDefinedId(uint32_t id) const; - /// Increments the instruction count. Used for diagnostic - int increment_instruction_count(); + /// Increments the total number of instructions in the file. + void increment_total_instructions() { total_instructions_++; } + + /// Increments the total number of functions in the file. + void increment_total_functions() { total_functions_++; } + + /// Allocates internal storage. Note, calling this will invalidate any + /// pointers to |ordered_instructions_| or |module_functions_| and, hence, + /// should only be called at the beginning of validation. + void preallocateStorage(); /// Returns the current layout section which is being processed ModuleLayoutSection current_layout_section() const; @@ -135,10 +169,10 @@ class ValidationState_t { /// Determines if the op instruction is part of the current section bool IsOpcodeInCurrentLayoutSection(SpvOp op); - libspirv::DiagnosticStream diag(spv_result_t error_code) const; + DiagnosticStream diag(spv_result_t error_code, const Instruction* inst) const; /// Returns the function states - std::deque& functions(); + std::vector& functions(); /// Returns the function states Function& current_function(); @@ -146,6 +180,7 @@ class ValidationState_t { /// Returns function state with the given id, or nullptr if no such function. const Function* function(uint32_t id) const; + Function* function(uint32_t id); /// Returns true if the called after a function instruction but before the /// function end instruction @@ -155,34 +190,32 @@ class ValidationState_t { /// instruction bool in_block() const; - /// Registers the given as an Entry Point with |execution_model|. - void RegisterEntryPointId(const uint32_t id, - SpvExecutionModel execution_model) { + struct EntryPointDescription { + std::string name; + std::vector interfaces; + }; + + /// Registers |id| as an entry point with |execution_model| and |interfaces|. + void RegisterEntryPoint(const uint32_t id, SpvExecutionModel execution_model, + EntryPointDescription&& desc) { entry_points_.push_back(id); - entry_point_interfaces_.emplace(id, std::vector()); entry_point_to_execution_models_[id].insert(execution_model); + entry_point_descriptions_[id].emplace_back(desc); } /// Returns a list of entry point function ids const std::vector& entry_points() const { return entry_points_; } - /// Adds a new interface id to the interfaces of the given entry point. - void RegisterInterfaceForEntryPoint(uint32_t entry_point, - uint32_t interface) { - entry_point_interfaces_[entry_point].push_back(interface); - } - /// Registers execution mode for the given entry point. void RegisterExecutionModeForEntryPoint(uint32_t entry_point, SpvExecutionMode execution_mode) { entry_point_to_execution_modes_[entry_point].insert(execution_mode); } - /// Returns the interfaces of a given entry point. If the given id is not a - /// valid Entry Point id, std::out_of_range exception is thrown. - const std::vector& entry_point_interfaces( - uint32_t entry_point) const { - return entry_point_interfaces_.at(entry_point); + /// Returns the interface descriptions of a given entry point. + const std::vector& entry_point_descriptions( + uint32_t entry_point) { + return entry_point_descriptions_.at(entry_point); } /// Returns Execution Models for the given Entry Point. @@ -208,6 +241,13 @@ class ValidationState_t { return &it->second; } + /// Traverses call tree and computes function_to_entry_points_. + /// Note: called after fully parsing the binary. + void ComputeFunctionToEntryPointMapping(); + + /// Returns all the entry points that can call |func|. + const std::vector& FunctionEntryPoints(uint32_t func) const; + /// Inserts an to the set of functions that are target of OpFunctionCall. void AddFunctionCallTarget(const uint32_t id) { function_call_targets_.insert(id); @@ -246,15 +286,21 @@ class ValidationState_t { /// Returns true if any of the capabilities is enabled, or if |capabilities| /// is an empty set. - bool HasAnyOfCapabilities(const libspirv::CapabilitySet& capabilities) const; + bool HasAnyOfCapabilities(const CapabilitySet& capabilities) const; /// Returns true if any of the extensions is enabled, or if |extensions| /// is an empty set. - bool HasAnyOfExtensions(const libspirv::ExtensionSet& extensions) const; + bool HasAnyOfExtensions(const ExtensionSet& extensions) const; /// Sets the addressing model of this module (logical/physical). void set_addressing_model(SpvAddressingModel am); + /// Returns true if the OpMemoryModel was found. + bool has_memory_model_specified() const { + return addressing_model_ != SpvAddressingModelMax && + memory_model_ != SpvMemoryModelMax; + } + /// Returns the addressing model of this module, or Logical if uninitialized. SpvAddressingModel addressing_model() const; @@ -266,8 +312,15 @@ class ValidationState_t { const AssemblyGrammar& grammar() const { return grammar_; } - /// Registers the instruction - void RegisterInstruction(const spv_parsed_instruction_t& inst); + /// Inserts the instruction into the list of ordered instructions in the file. + Instruction* AddOrderedInstruction(const spv_parsed_instruction_t* inst); + + /// Registers the instruction. This will add the instruction to the list of + /// definitions and register sampled image consumers. + void RegisterInstruction(Instruction* inst); + + /// Registers the debug instruction information. + void RegisterDebugInstruction(const Instruction* inst); /// Registers the decoration for the given void RegisterDecorationForId(uint32_t id, const Decoration& dec) { @@ -318,8 +371,8 @@ class ValidationState_t { /// nullptr Instruction* FindDef(uint32_t id); - /// Returns a deque of instructions in the order they appear in the binary - const std::deque& ordered_instructions() const { + /// Returns the instructions in the order they appear in the binary + const std::vector& ordered_instructions() const { return ordered_instructions_; } @@ -354,6 +407,12 @@ class ValidationState_t { /// Inserts a new to the set of Local Variables. void registerLocalVariable(const uint32_t id) { local_vars_.insert(id); } + // Returns true if using relaxed block layout, equivalent to + // VK_KHR_relaxed_block_layout. + bool IsRelaxedBlockLayout() const { + return features_.env_relaxed_block_layout || options()->relax_block_layout; + } + /// Sets the struct nesting depth for a given struct ID void set_struct_nesting_depth(uint32_t id, uint32_t depth) { struct_nesting_depth_[id] = depth; @@ -379,7 +438,7 @@ class ValidationState_t { /// Adds the instruction data to unique_type_declarations_. /// Returns false if an identical type declaration already exists. - bool RegisterUniqueTypeDeclaration(const spv_parsed_instruction_t& inst); + bool RegisterUniqueTypeDeclaration(const Instruction* inst); // Returns type_id of the scalar component of |id|. // |id| can be either @@ -445,7 +504,7 @@ class ValidationState_t { // Returns type_id for given id operand if it has a type or zero otherwise. // |operand_index| is expected to be pointing towards an operand which is an // id. - uint32_t GetOperandTypeId(const spv_parsed_instruction_t* inst, + uint32_t GetOperandTypeId(const Instruction* inst, size_t operand_index) const; // Provides information on pointer type. Returns false iff not pointer type. @@ -456,6 +515,12 @@ class ValidationState_t { // Returns tuple . std::tuple EvalInt32IfConst(uint32_t id); + // Returns the disassembly string for the given instruction. + std::string Disassemble(const Instruction& inst) const; + + // Returns the disassembly string for the given instruction. + std::string Disassemble(const uint32_t* words, uint16_t num_words) const; + private: ValidationState_t(const ValidationState_t&); @@ -464,8 +529,20 @@ class ValidationState_t { /// Stores the Validator command line options. Must be a valid options object. const spv_const_validator_options options_; - /// Tracks the number of instructions evaluated by the validator - int instruction_counter_; + /// The SPIR-V binary module we're validating. + const uint32_t* words_; + const size_t num_words_; + + /// The generator of the SPIR-V. + uint32_t generator_ = 0; + + /// The version of the SPIR-V. + uint32_t version_ = 0; + + /// The total number of instructions in the binary. + size_t total_instructions_ = 0; + /// The total number of functions in the binary. + size_t total_functions_ = 0; /// IDs which have been forward declared but have not been defined std::unordered_set unresolved_forward_ids_; @@ -486,18 +563,16 @@ class ValidationState_t { /// A list of functions in the module. /// Pointers to objects in this container are guaranteed to be stable and /// valid until the end of lifetime of the validation state. - std::deque module_functions_; + std::vector module_functions_; /// Capabilities declared in the module - libspirv::CapabilitySet module_capabilities_; + CapabilitySet module_capabilities_; /// Extensions declared in the module - libspirv::ExtensionSet module_extensions_; + ExtensionSet module_extensions_; /// List of all instructions in the order they appear in the binary - /// Pointers to objects in this container are guaranteed to be stable and - /// valid until the end of lifetime of the validation state. - std::deque ordered_instructions_; + std::vector ordered_instructions_; /// Instructions that can be referenced by Ids std::unordered_map all_definitions_; @@ -505,8 +580,9 @@ class ValidationState_t { /// IDs that are entry points, ie, arguments to OpEntryPoint. std::vector entry_points_; - /// Maps an entry point id to its interfaces. - std::unordered_map> entry_point_interfaces_; + /// Maps an entry point id to its desciptions. + std::unordered_map> + entry_point_descriptions_; /// Functions IDs that are target of OpFunctionCall. std::unordered_set function_call_targets_; @@ -544,7 +620,7 @@ class ValidationState_t { bool in_function_; /// The state of optional features. These are determined by capabilities - /// declared by the module. + /// declared by the module and the environment. Feature features_; /// Maps function ids to function stat objects. @@ -559,8 +635,14 @@ class ValidationState_t { /// Mapping entry point -> execution modes. std::unordered_map> entry_point_to_execution_modes_; + + /// Mapping function -> array of entry points inside this + /// module which can (indirectly) call the function. + std::unordered_map> function_to_entry_points_; + const std::vector empty_ids_; }; -} // namespace libspirv +} // namespace val +} // namespace spvtools -#endif /// LIBSPIRV_VAL_VALIDATIONSTATE_H_ +#endif // SOURCE_VAL_VALIDATION_STATE_H_ diff --git a/3rdparty/spirv-tools/source/validate.cpp b/3rdparty/spirv-tools/source/validate.cpp deleted file mode 100644 index 2c3386c9b..000000000 --- a/3rdparty/spirv-tools/source/validate.cpp +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright (c) 2015-2016 The Khronos Group Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "validate.h" - -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "binary.h" -#include "diagnostic.h" -#include "enum_string_mapping.h" -#include "extensions.h" -#include "instruction.h" -#include "opcode.h" -#include "operand.h" -#include "spirv-tools/libspirv.h" -#include "spirv_constant.h" -#include "spirv_endian.h" -#include "spirv_target_env.h" -#include "spirv_validator_options.h" -#include "val/construct.h" -#include "val/function.h" -#include "val/validation_state.h" - -using std::function; -using std::ostream_iterator; -using std::string; -using std::stringstream; -using std::transform; -using std::vector; -using std::placeholders::_1; - -using libspirv::CfgPass; -using libspirv::DataRulesPass; -using libspirv::Extension; -using libspirv::IdPass; -using libspirv::InstructionPass; -using libspirv::LiteralsPass; -using libspirv::ModuleLayoutPass; -using libspirv::ValidationState_t; - -spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, - const uint64_t count, - const ValidationState_t& state, - spv_position position) { - position->index = SPV_INDEX_INSTRUCTION; - if (auto error = spvValidateInstructionIDs(pInsts, count, state, position)) - return error; - return SPV_SUCCESS; -} - -namespace { - -// TODO(umar): Validate header -// TODO(umar): The binary parser validates the magic word, and the length of the -// header, but nothing else. -spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic, - uint32_t version, uint32_t generator, uint32_t id_bound, - uint32_t reserved) { - // Record the ID bound so that the validator can ensure no ID is out of bound. - ValidationState_t& _ = *(reinterpret_cast(user_data)); - _.setIdBound(id_bound); - - (void)endian; - (void)magic; - (void)version; - (void)generator; - (void)id_bound; - (void)reserved; - return SPV_SUCCESS; -} - -// Improves diagnostic messages by collecting names of IDs -// NOTE: This function returns void and is not involved in validation -void DebugInstructionPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - switch (inst->opcode) { - case SpvOpName: { - const uint32_t target = *(inst->words + inst->operands[0].offset); - const char* str = - reinterpret_cast(inst->words + inst->operands[1].offset); - _.AssignNameToId(target, str); - } break; - case SpvOpMemberName: { - const uint32_t target = *(inst->words + inst->operands[0].offset); - const char* str = - reinterpret_cast(inst->words + inst->operands[2].offset); - _.AssignNameToId(target, str); - } break; - case SpvOpSourceContinued: - case SpvOpSource: - case SpvOpSourceExtension: - case SpvOpString: - case SpvOpLine: - case SpvOpNoLine: - - default: - break; - } -} - -// Parses OpExtension instruction and registers extension. -void RegisterExtension(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const std::string extension_str = libspirv::GetExtensionString(inst); - Extension extension; - if (!GetExtensionFromString(extension_str.c_str(), &extension)) { - // The error will be logged in the ProcessInstruction pass. - return; - } - - _.RegisterExtension(extension); -} - -// Parses the beginning of the module searching for OpExtension instructions. -// Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION -// once an instruction which is not SpvOpCapability and SpvOpExtension is -// encountered. According to the SPIR-V spec extensions are declared after -// capabilities and before everything else. -spv_result_t ProcessExtensions(void* user_data, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - if (opcode == SpvOpCapability) return SPV_SUCCESS; - - if (opcode == SpvOpExtension) { - ValidationState_t& _ = *(reinterpret_cast(user_data)); - RegisterExtension(_, inst); - return SPV_SUCCESS; - } - - // OpExtension block is finished, requesting termination. - return SPV_REQUESTED_TERMINATION; -} - -spv_result_t ProcessInstruction(void* user_data, - const spv_parsed_instruction_t* inst) { - ValidationState_t& _ = *(reinterpret_cast(user_data)); - _.increment_instruction_count(); - if (static_cast(inst->opcode) == SpvOpEntryPoint) { - const auto entry_point = inst->words[2]; - const SpvExecutionModel execution_model = SpvExecutionModel(inst->words[1]); - _.RegisterEntryPointId(entry_point, execution_model); - // Operand 3 and later are the of interfaces for the entry point. - for (int i = 3; i < inst->num_operands; ++i) { - _.RegisterInterfaceForEntryPoint(entry_point, - inst->words[inst->operands[i].offset]); - } - } - if (static_cast(inst->opcode) == SpvOpFunctionCall) { - _.AddFunctionCallTarget(inst->words[3]); - } - - DebugInstructionPass(_, inst); - if (auto error = CapabilityPass(_, inst)) return error; - if (auto error = DataRulesPass(_, inst)) return error; - if (auto error = IdPass(_, inst)) return error; - if (auto error = ModuleLayoutPass(_, inst)) return error; - if (auto error = CfgPass(_, inst)) return error; - if (auto error = InstructionPass(_, inst)) return error; - if (auto error = TypeUniquePass(_, inst)) return error; - if (auto error = ArithmeticsPass(_, inst)) return error; - if (auto error = CompositesPass(_, inst)) return error; - if (auto error = ConversionPass(_, inst)) return error; - if (auto error = DerivativesPass(_, inst)) return error; - if (auto error = LogicalsPass(_, inst)) return error; - if (auto error = BitwisePass(_, inst)) return error; - if (auto error = ExtInstPass(_, inst)) return error; - if (auto error = ImagePass(_, inst)) return error; - if (auto error = AtomicsPass(_, inst)) return error; - if (auto error = BarriersPass(_, inst)) return error; - if (auto error = PrimitivesPass(_, inst)) return error; - if (auto error = LiteralsPass(_, inst)) return error; - - return SPV_SUCCESS; -} - -void printDot(const ValidationState_t& _, const libspirv::BasicBlock& other) { - string block_string; - if (other.successors()->empty()) { - block_string += "end "; - } else { - for (auto block : *other.successors()) { - block_string += _.getIdOrName(block->id()) + " "; - } - } - printf("%10s -> {%s\b}\n", _.getIdOrName(other.id()).c_str(), - block_string.c_str()); -} - -void PrintBlocks(ValidationState_t& _, libspirv::Function func) { - assert(func.first_block()); - - printf("%10s -> %s\n", _.getIdOrName(func.id()).c_str(), - _.getIdOrName(func.first_block()->id()).c_str()); - for (const auto& block : func.ordered_blocks()) { - printDot(_, *block); - } -} - -#ifdef __clang__ -#define UNUSED(func) [[gnu::unused]] func -#elif defined(__GNUC__) -#define UNUSED(func) \ - func __attribute__((unused)); \ - func -#elif defined(_MSC_VER) -#define UNUSED(func) func -#endif - -UNUSED(void PrintDotGraph(ValidationState_t& _, libspirv::Function func)) { - if (func.first_block()) { - string func_name(_.getIdOrName(func.id())); - printf("digraph %s {\n", func_name.c_str()); - PrintBlocks(_, func); - printf("}\n"); - } -} - -spv_result_t ValidateBinaryUsingContextAndValidationState( - const spv_context_t& context, const uint32_t* words, const size_t num_words, - spv_diagnostic* pDiagnostic, ValidationState_t* vstate) { - auto binary = std::unique_ptr( - new spv_const_binary_t{words, num_words}); - - spv_endianness_t endian; - spv_position_t position = {}; - if (spvBinaryEndianness(binary.get(), &endian)) { - return libspirv::DiagnosticStream(position, context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V magic number."; - } - - spv_header_t header; - if (spvBinaryHeaderGet(binary.get(), endian, &header)) { - return libspirv::DiagnosticStream(position, context.consumer, - SPV_ERROR_INVALID_BINARY) - << "Invalid SPIR-V header."; - } - - if (header.version > spvVersionForTargetEnv(context.target_env)) { - return libspirv::DiagnosticStream(position, context.consumer, - SPV_ERROR_WRONG_VERSION) - << "Invalid SPIR-V binary version " - << SPV_SPIRV_VERSION_MAJOR_PART(header.version) << "." - << SPV_SPIRV_VERSION_MINOR_PART(header.version) - << " for target environment " - << spvTargetEnvDescription(context.target_env) << "."; - } - - // Look for OpExtension instructions and register extensions. - // Diagnostics if any will be produced in the next pass (ProcessInstruction). - spvBinaryParse(&context, vstate, words, num_words, - /* parsed_header = */ nullptr, ProcessExtensions, - /* diagnostic = */ nullptr); - - // NOTE: Parse the module and perform inline validation checks. These - // checks do not require the the knowledge of the whole module. - if (auto error = spvBinaryParse(&context, vstate, words, num_words, setHeader, - ProcessInstruction, pDiagnostic)) - return error; - - if (vstate->in_function_body()) - return vstate->diag(SPV_ERROR_INVALID_LAYOUT) - << "Missing OpFunctionEnd at end of module."; - - // TODO(umar): Add validation checks which require the parsing of the entire - // module. Use the information from the ProcessInstruction pass to make the - // checks. - if (vstate->unresolved_forward_id_count() > 0) { - stringstream ss; - vector ids = vstate->UnresolvedForwardIds(); - - transform(begin(ids), end(ids), ostream_iterator(ss, " "), - bind(&ValidationState_t::getIdName, std::ref(*vstate), _1)); - - auto id_str = ss.str(); - return vstate->diag(SPV_ERROR_INVALID_ID) - << "The following forward referenced IDs have not been defined:\n" - << id_str.substr(0, id_str.size() - 1); - } - - // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi - // must only be preceeded by SpvOpLabel, SpvOpPhi, or SpvOpLine. - if (auto error = ValidateAdjacency(*vstate)) return error; - - // CFG checks are performed after the binary has been parsed - // and the CFGPass has collected information about the control flow - if (auto error = PerformCfgChecks(*vstate)) return error; - if (auto error = UpdateIdUse(*vstate)) return error; - if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error; - if (auto error = ValidateDecorations(*vstate)) return error; - - // Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the - // SPIRV spec: - // * There is at least one OpEntryPoint instruction, unless the Linkage - // capability is being used. - // * No function can be targeted by both an OpEntryPoint instruction and an - // OpFunctionCall instruction. - if (vstate->entry_points().empty() && - !vstate->HasCapability(SpvCapabilityLinkage)) { - return vstate->diag(SPV_ERROR_INVALID_BINARY) - << "No OpEntryPoint instruction was found. This is only allowed if " - "the Linkage capability is being used."; - } - for (const auto& entry_point : vstate->entry_points()) { - if (vstate->IsFunctionCallTarget(entry_point)) { - return vstate->diag(SPV_ERROR_INVALID_BINARY) - << "A function (" << entry_point - << ") may not be targeted by both an OpEntryPoint instruction and " - "an OpFunctionCall instruction."; - } - } - - // NOTE: Copy each instruction for easier processing - std::vector instructions; - // Expect average instruction length to be a bit over 2 words. - instructions.reserve(binary->wordCount / 2); - uint64_t index = SPV_INDEX_INSTRUCTION; - while (index < binary->wordCount) { - uint16_t wordCount; - uint16_t opcode; - spvOpcodeSplit(spvFixWord(binary->code[index], endian), &wordCount, - &opcode); - spv_instruction_t inst; - spvInstructionCopy(&binary->code[index], static_cast(opcode), - wordCount, endian, &inst); - instructions.emplace_back(std::move(inst)); - index += wordCount; - } - - position.index = SPV_INDEX_INSTRUCTION; - if (auto error = spvValidateIDs(instructions.data(), instructions.size(), - *vstate, &position)) - return error; - - if (auto error = ValidateBuiltIns(*vstate)) return error; - - return SPV_SUCCESS; -} -} // anonymous namespace - -spv_result_t spvValidate(const spv_const_context context, - const spv_const_binary binary, - spv_diagnostic* pDiagnostic) { - return spvValidateBinary(context, binary->code, binary->wordCount, - pDiagnostic); -} - -spv_result_t spvValidateBinary(const spv_const_context context, - const uint32_t* words, const size_t num_words, - spv_diagnostic* pDiagnostic) { - spv_context_t hijack_context = *context; - if (pDiagnostic) { - *pDiagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); - } - - // This interface is used for default command line options. - spv_validator_options default_options = spvValidatorOptionsCreate(); - - // Create the ValidationState using the context and default options. - ValidationState_t vstate(&hijack_context, default_options); - - spv_result_t result = ValidateBinaryUsingContextAndValidationState( - hijack_context, words, num_words, pDiagnostic, &vstate); - - spvValidatorOptionsDestroy(default_options); - return result; -} - -spv_result_t spvValidateWithOptions(const spv_const_context context, - spv_const_validator_options options, - const spv_const_binary binary, - spv_diagnostic* pDiagnostic) { - spv_context_t hijack_context = *context; - if (pDiagnostic) { - *pDiagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); - } - - // Create the ValidationState using the context. - ValidationState_t vstate(&hijack_context, options); - - return ValidateBinaryUsingContextAndValidationState( - hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate); -} - -namespace spvtools { - -spv_result_t ValidateBinaryAndKeepValidationState( - const spv_const_context context, spv_const_validator_options options, - const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, - std::unique_ptr* vstate) { - spv_context_t hijack_context = *context; - if (pDiagnostic) { - *pDiagnostic = nullptr; - libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); - } - - vstate->reset(new ValidationState_t(&hijack_context, options)); - - return ValidateBinaryUsingContextAndValidationState( - hijack_context, words, num_words, pDiagnostic, vstate->get()); -} - -spv_result_t ValidateInstructionAndUpdateValidationState( - ValidationState_t* vstate, const spv_parsed_instruction_t* inst) { - return ProcessInstruction(vstate, inst); -} - -} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/validate_adjacency.cpp b/3rdparty/spirv-tools/source/validate_adjacency.cpp deleted file mode 100644 index 75cea5294..000000000 --- a/3rdparty/spirv-tools/source/validate_adjacency.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2018 LunarG Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Validates correctness of the intra-block preconditions of SPIR-V -// instructions. - -#include "validate.h" - -#include - -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { - -spv_result_t ValidateAdjacency(ValidationState_t& _) { - const auto& instructions = _.ordered_instructions(); - for (auto i = instructions.cbegin(); i != instructions.cend(); ++i) { - switch (i->opcode()) { - case SpvOpPhi: - if (i != instructions.cbegin()) { - switch (prev(i)->opcode()) { - case SpvOpLabel: - case SpvOpPhi: - case SpvOpLine: - break; - default: - return _.diag(SPV_ERROR_INVALID_DATA) - << "OpPhi must appear before all non-OpPhi instructions " - << "(except for OpLine, which can be mixed with OpPhi)."; - } - } - break; - case SpvOpLoopMerge: - if (next(i) != instructions.cend()) { - switch (next(i)->opcode()) { - case SpvOpBranch: - case SpvOpBranchConditional: - break; - default: - return _.diag(SPV_ERROR_INVALID_DATA) - << "OpLoopMerge must immediately precede either an " - << "OpBranch or OpBranchConditional instruction. " - << "OpLoopMerge must be the second-to-last instruction in " - << "its block."; - } - } - break; - case SpvOpSelectionMerge: - if (next(i) != instructions.cend()) { - switch (next(i)->opcode()) { - case SpvOpBranchConditional: - case SpvOpSwitch: - break; - default: - return _.diag(SPV_ERROR_INVALID_DATA) - << "OpSelectionMerge must immediately precede either an " - << "OpBranchConditional or OpSwitch instruction. " - << "OpSelectionMerge must be the second-to-last " - << "instruction in its block."; - } - } - default: - break; - } - } - - return SPV_SUCCESS; -} - -} // namespace libspirv diff --git a/3rdparty/spirv-tools/source/validate_cfg.cpp b/3rdparty/spirv-tools/source/validate_cfg.cpp deleted file mode 100644 index d601e12f7..000000000 --- a/3rdparty/spirv-tools/source/validate_cfg.cpp +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright (c) 2015-2016 The Khronos Group Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cfa.h" -#include "validate.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "spirv_validator_options.h" -#include "val/basic_block.h" -#include "val/construct.h" -#include "val/function.h" -#include "val/validation_state.h" - -using std::find; -using std::function; -using std::get; -using std::ignore; -using std::make_pair; -using std::make_tuple; -using std::numeric_limits; -using std::pair; -using std::string; -using std::tie; -using std::transform; -using std::tuple; -using std::unordered_map; -using std::unordered_set; -using std::vector; - -using libspirv::BasicBlock; - -namespace libspirv { - -namespace { - -using bb_ptr = BasicBlock*; -using cbb_ptr = const BasicBlock*; -using bb_iter = vector::const_iterator; - -} // namespace - -void printDominatorList(const BasicBlock& b) { - std::cout << b.id() << " is dominated by: "; - const BasicBlock* bb = &b; - while (bb->immediate_dominator() != bb) { - bb = bb->immediate_dominator(); - std::cout << bb->id() << " "; - } -} - -#define CFG_ASSERT(ASSERT_FUNC, TARGET) \ - if (spv_result_t rcode = ASSERT_FUNC(_, TARGET)) return rcode - -spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) { - if (_.current_function().IsFirstBlock(target)) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "First block " << _.getIdName(target) << " of function " - << _.getIdName(_.current_function().id()) << " is targeted by block " - << _.getIdName(_.current_function().current_block()->id()); - } - return SPV_SUCCESS; -} - -spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) { - if (_.current_function().IsBlockType(merge_block, kBlockTypeMerge)) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "Block " << _.getIdName(merge_block) - << " is already a merge block for another header"; - } - return SPV_SUCCESS; -} - -/// Update the continue construct's exit blocks once the backedge blocks are -/// identified in the CFG. -void UpdateContinueConstructExitBlocks( - Function& function, const vector>& back_edges) { - auto& constructs = function.constructs(); - // TODO(umar): Think of a faster way to do this - for (auto& edge : back_edges) { - uint32_t back_edge_block_id; - uint32_t loop_header_block_id; - tie(back_edge_block_id, loop_header_block_id) = edge; - auto is_this_header = [=](Construct& c) { - return c.type() == ConstructType::kLoop && - c.entry_block()->id() == loop_header_block_id; - }; - - for (auto construct : constructs) { - if (is_this_header(construct)) { - Construct* continue_construct = - construct.corresponding_constructs().back(); - assert(continue_construct->type() == ConstructType::kContinue); - - BasicBlock* back_edge_block; - tie(back_edge_block, ignore) = function.GetBlock(back_edge_block_id); - continue_construct->set_exit(back_edge_block); - } - } - } -} - -tuple ConstructNames(ConstructType type) { - string construct_name, header_name, exit_name; - - switch (type) { - case ConstructType::kSelection: - construct_name = "selection"; - header_name = "selection header"; - exit_name = "merge block"; - break; - case ConstructType::kLoop: - construct_name = "loop"; - header_name = "loop header"; - exit_name = "merge block"; - break; - case ConstructType::kContinue: - construct_name = "continue"; - header_name = "continue target"; - exit_name = "back-edge block"; - break; - case ConstructType::kCase: - construct_name = "case"; - header_name = "case entry block"; - exit_name = "case exit block"; - break; - default: - assert(1 == 0 && "Not defined type"); - } - - return make_tuple(construct_name, header_name, exit_name); -} - -/// Constructs an error message for construct validation errors -string ConstructErrorString(const Construct& construct, - const string& header_string, - const string& exit_string, - const string& dominate_text) { - string construct_name, header_name, exit_name; - tie(construct_name, header_name, exit_name) = - ConstructNames(construct.type()); - - // TODO(umar): Add header block for continue constructs to error message - return "The " + construct_name + " construct with the " + header_name + " " + - header_string + " " + dominate_text + " the " + exit_name + " " + - exit_string; -} - -spv_result_t StructuredControlFlowChecks( - const ValidationState_t& _, const Function& function, - const vector>& back_edges) { - /// Check all backedges target only loop headers and have exactly one - /// back-edge branching to it - - // Map a loop header to blocks with back-edges to the loop header. - std::map> loop_latch_blocks; - for (auto back_edge : back_edges) { - uint32_t back_edge_block; - uint32_t header_block; - tie(back_edge_block, header_block) = back_edge; - if (!function.IsBlockType(header_block, kBlockTypeLoop)) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "Back-edges (" << _.getIdName(back_edge_block) << " -> " - << _.getIdName(header_block) - << ") can only be formed between a block and a loop header."; - } - loop_latch_blocks[header_block].insert(back_edge_block); - } - - // Check the loop headers have exactly one back-edge branching to it - for (BasicBlock* loop_header : function.ordered_blocks()) { - if (!loop_header->reachable()) continue; - if (!loop_header->is_type(kBlockTypeLoop)) continue; - auto loop_header_id = loop_header->id(); - auto num_latch_blocks = loop_latch_blocks[loop_header_id].size(); - if (num_latch_blocks != 1) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "Loop header " << _.getIdName(loop_header_id) - << " is targeted by " << num_latch_blocks - << " back-edge blocks but the standard requires exactly one"; - } - } - - // Check construct rules - for (const Construct& construct : function.constructs()) { - auto header = construct.entry_block(); - auto merge = construct.exit_block(); - - if (header->reachable() && !merge) { - string construct_name, header_name, exit_name; - tie(construct_name, header_name, exit_name) = - ConstructNames(construct.type()); - return _.diag(SPV_ERROR_INTERNAL) - << "Construct " + construct_name + " with " + header_name + " " + - _.getIdName(header->id()) + " does not have a " + - exit_name + ". This may be a bug in the validator."; - } - - // If the exit block is reachable then it's dominated by the - // header. - if (merge && merge->reachable()) { - if (!header->dominates(*merge)) { - return _.diag(SPV_ERROR_INVALID_CFG) << ConstructErrorString( - construct, _.getIdName(header->id()), - _.getIdName(merge->id()), "does not dominate"); - } - // If it's really a merge block for a selection or loop, then it must be - // *strictly* dominated by the header. - if (construct.ExitBlockIsMergeBlock() && (header == merge)) { - return _.diag(SPV_ERROR_INVALID_CFG) << ConstructErrorString( - construct, _.getIdName(header->id()), - _.getIdName(merge->id()), "does not strictly dominate"); - } - } - // Check post-dominance for continue constructs. But dominance and - // post-dominance only make sense when the construct is reachable. - if (header->reachable() && construct.type() == ConstructType::kContinue) { - if (!merge->postdominates(*header)) { - return _.diag(SPV_ERROR_INVALID_CFG) << ConstructErrorString( - construct, _.getIdName(header->id()), - _.getIdName(merge->id()), "is not post dominated by"); - } - } - // TODO(umar): an OpSwitch block dominates all its defined case - // constructs - // TODO(umar): each case construct has at most one branch to another - // case construct - // TODO(umar): each case construct is branched to by at most one other - // case construct - // TODO(umar): if Target T1 branches to Target T2, or if Target T1 - // branches to the Default and the Default branches to Target T2, then - // T1 must immediately precede T2 in the list of the OpSwitch Target - // operands - } - return SPV_SUCCESS; -} - -spv_result_t PerformCfgChecks(ValidationState_t& _) { - for (auto& function : _.functions()) { - // Check all referenced blocks are defined within a function - if (function.undefined_block_count() != 0) { - string undef_blocks("{"); - for (auto undefined_block : function.undefined_blocks()) { - undef_blocks += _.getIdName(undefined_block) + " "; - } - return _.diag(SPV_ERROR_INVALID_CFG) - << "Block(s) " << undef_blocks << "\b}" - << " are referenced but not defined in function " - << _.getIdName(function.id()); - } - - // Set each block's immediate dominator and immediate postdominator, - // and find all back-edges. - // - // We want to analyze all the blocks in the function, even in degenerate - // control flow cases including unreachable blocks. So use the augmented - // CFG to ensure we cover all the blocks. - vector postorder; - vector postdom_postorder; - vector> back_edges; - auto ignore_block = [](cbb_ptr) {}; - auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; - if (!function.ordered_blocks().empty()) { - /// calculate dominators - spvtools::CFA::DepthFirstTraversal( - function.first_block(), function.AugmentedCFGSuccessorsFunction(), - ignore_block, [&](cbb_ptr b) { postorder.push_back(b); }, - ignore_edge); - auto edges = spvtools::CFA::CalculateDominators( - postorder, function.AugmentedCFGPredecessorsFunction()); - for (auto edge : edges) { - edge.first->SetImmediateDominator(edge.second); - } - - /// calculate post dominators - spvtools::CFA::DepthFirstTraversal( - function.pseudo_exit_block(), - function.AugmentedCFGPredecessorsFunction(), ignore_block, - [&](cbb_ptr b) { postdom_postorder.push_back(b); }, ignore_edge); - auto postdom_edges = - spvtools::CFA::CalculateDominators( - postdom_postorder, function.AugmentedCFGSuccessorsFunction()); - for (auto edge : postdom_edges) { - edge.first->SetImmediatePostDominator(edge.second); - } - /// calculate back edges. - spvtools::CFA::DepthFirstTraversal( - function.pseudo_entry_block(), - function - .AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge(), - ignore_block, ignore_block, [&](cbb_ptr from, cbb_ptr to) { - back_edges.emplace_back(from->id(), to->id()); - }); - } - UpdateContinueConstructExitBlocks(function, back_edges); - - auto& blocks = function.ordered_blocks(); - if (!blocks.empty()) { - // Check if the order of blocks in the binary appear before the blocks - // they dominate - for (auto block = begin(blocks) + 1; block != end(blocks); ++block) { - if (auto idom = (*block)->immediate_dominator()) { - if (idom != function.pseudo_entry_block() && - block == std::find(begin(blocks), block, idom)) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "Block " << _.getIdName((*block)->id()) - << " appears in the binary before its dominator " - << _.getIdName(idom->id()); - } - } - } - // If we have structed control flow, check that no block has a control - // flow nesting depth larger than the limit. - if (_.HasCapability(SpvCapabilityShader)) { - const int control_flow_nesting_depth_limit = - _.options()->universal_limits_.max_control_flow_nesting_depth; - for (auto block = begin(blocks); block != end(blocks); ++block) { - if (function.GetBlockDepth(*block) > - control_flow_nesting_depth_limit) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "Maximum Control Flow nesting depth exceeded."; - } - } - } - } - - /// Structured control flow checks are only required for shader capabilities - if (_.HasCapability(SpvCapabilityShader)) { - if (auto error = StructuredControlFlowChecks(_, function, back_edges)) - return error; - } - } - return SPV_SUCCESS; -} - -spv_result_t CfgPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - SpvOp opcode = static_cast(inst->opcode); - switch (opcode) { - case SpvOpLabel: - if (auto error = _.current_function().RegisterBlock(inst->result_id)) - return error; - break; - case SpvOpLoopMerge: { - uint32_t merge_block = inst->words[inst->operands[0].offset]; - uint32_t continue_block = inst->words[inst->operands[1].offset]; - CFG_ASSERT(MergeBlockAssert, merge_block); - - if (auto error = _.current_function().RegisterLoopMerge(merge_block, - continue_block)) - return error; - } break; - case SpvOpSelectionMerge: { - uint32_t merge_block = inst->words[inst->operands[0].offset]; - CFG_ASSERT(MergeBlockAssert, merge_block); - - if (auto error = _.current_function().RegisterSelectionMerge(merge_block)) - return error; - } break; - case SpvOpBranch: { - uint32_t target = inst->words[inst->operands[0].offset]; - CFG_ASSERT(FirstBlockAssert, target); - - _.current_function().RegisterBlockEnd({target}, opcode); - } break; - case SpvOpBranchConditional: { - uint32_t tlabel = inst->words[inst->operands[1].offset]; - uint32_t flabel = inst->words[inst->operands[2].offset]; - CFG_ASSERT(FirstBlockAssert, tlabel); - CFG_ASSERT(FirstBlockAssert, flabel); - - _.current_function().RegisterBlockEnd({tlabel, flabel}, opcode); - } break; - - case SpvOpSwitch: { - vector cases; - for (int i = 1; i < inst->num_operands; i += 2) { - uint32_t target = inst->words[inst->operands[i].offset]; - CFG_ASSERT(FirstBlockAssert, target); - cases.push_back(target); - } - _.current_function().RegisterBlockEnd({cases}, opcode); - } break; - case SpvOpReturn: { - const uint32_t return_type = _.current_function().GetResultTypeId(); - const Instruction* return_type_inst = _.FindDef(return_type); - assert(return_type_inst); - if (return_type_inst->opcode() != SpvOpTypeVoid) - return _.diag(SPV_ERROR_INVALID_CFG) - << "OpReturn can only be called from a function with void " - << "return type."; - } - // Fallthrough. - case SpvOpKill: - case SpvOpReturnValue: - case SpvOpUnreachable: - _.current_function().RegisterBlockEnd(vector(), opcode); - if (opcode == SpvOpKill) { - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelFragment, - "OpKill requires Fragment execution model"); - } - break; - default: - break; - } - return SPV_SUCCESS; -} -} // namespace libspirv diff --git a/3rdparty/spirv-tools/source/validate_composites.cpp b/3rdparty/spirv-tools/source/validate_composites.cpp deleted file mode 100644 index c163786e5..000000000 --- a/3rdparty/spirv-tools/source/validate_composites.cpp +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright (c) 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Validates correctness of composite SPIR-V instructions. - -#include "validate.h" - -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { - -namespace { - -// Returns the type of the value accessed by OpCompositeExtract or -// OpCompositeInsert instruction. The function traverses the hierarchy of -// nested data structures (structs, arrays, vectors, matrices) as directed by -// the sequence of indices in the instruction. May return error if traversal -// fails (encountered non-composite, out of bounds, nesting too deep). -// Returns the type of Composite operand if the instruction has no indices. -spv_result_t GetExtractInsertValueType(ValidationState_t& _, - const spv_parsed_instruction_t& inst, - uint32_t* member_type) { - const SpvOp opcode = static_cast(inst.opcode); - assert(opcode == SpvOpCompositeExtract || opcode == SpvOpCompositeInsert); - uint32_t word_index = opcode == SpvOpCompositeExtract ? 4 : 5; - const uint32_t num_words = static_cast(inst.num_words); - const uint32_t composite_id_index = word_index - 1; - - const uint32_t num_indices = num_words - word_index; - const uint32_t kCompositeExtractInsertMaxNumIndices = 255; - if (num_indices > kCompositeExtractInsertMaxNumIndices) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "The number of indexes in Op" << spvOpcodeString(opcode) - << " may not exceed " << kCompositeExtractInsertMaxNumIndices - << ". Found " << num_indices << " indexes."; - } - - *member_type = _.GetTypeId(inst.words[composite_id_index]); - if (*member_type == 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Composite to be an object of composite type"; - } - - for (; word_index < num_words; ++word_index) { - const uint32_t component_index = inst.words[word_index]; - const Instruction* const type_inst = _.FindDef(*member_type); - assert(type_inst); - switch (type_inst->opcode()) { - case SpvOpTypeVector: { - *member_type = type_inst->word(2); - const uint32_t vector_size = type_inst->word(3); - if (component_index >= vector_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": vector access is out of bounds, vector size is " - << vector_size << ", but access index is " << component_index; - } - break; - } - case SpvOpTypeMatrix: { - *member_type = type_inst->word(2); - const uint32_t num_cols = type_inst->word(3); - if (component_index >= num_cols) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": matrix access is out of bounds, matrix has " << num_cols - << " columns, but access index is " << component_index; - } - break; - } - case SpvOpTypeArray: { - uint64_t array_size = 0; - auto size = _.FindDef(type_inst->word(3)); - *member_type = type_inst->word(2); - if (spvOpcodeIsSpecConstant(size->opcode())) { - // Cannot verify against the size of this array. - break; - } - - if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { - assert(0 && "Array type definition is corrupt"); - } - if (component_index >= array_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": array access is out of bounds, array size is " - << array_size << ", but access index is " << component_index; - } - break; - } - case SpvOpTypeRuntimeArray: { - *member_type = type_inst->word(2); - // Array size is unknown. - break; - } - case SpvOpTypeStruct: { - const size_t num_struct_members = type_inst->words().size() - 2; - if (component_index >= num_struct_members) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Index is out of bounds: Op" << spvOpcodeString(opcode) - << " can not find index " << component_index - << " into the structure '" << type_inst->id() - << "'. This structure has " << num_struct_members - << " members. Largest valid index is " - << num_struct_members - 1 << "."; - } - *member_type = type_inst->word(component_index + 2); - break; - } - default: - return _.diag(SPV_ERROR_INVALID_DATA) - << "Op" << spvOpcodeString(opcode) - << " reached non-composite type while indexes still remain to " - "be traversed."; - } - } - - return SPV_SUCCESS; -} - -} // anonymous namespace - -// Validates correctness of composite instructions. -spv_result_t CompositesPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; - const uint32_t num_operands = static_cast(inst->num_operands); - - switch (opcode) { - case SpvOpVectorExtractDynamic: { - const SpvOp result_opcode = _.GetIdOpcode(result_type); - if (!spvOpcodeIsScalarType(result_opcode)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be a scalar type"; - } - - const uint32_t vector_type = _.GetOperandTypeId(inst, 2); - const SpvOp vector_opcode = _.GetIdOpcode(vector_type); - if (vector_opcode != SpvOpTypeVector) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Vector type to be OpTypeVector"; - } - - if (_.GetComponentType(vector_type) != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Vector component type to be equal to Result Type"; - } - - const uint32_t index_type = _.GetOperandTypeId(inst, 3); - if (!_.IsIntScalarType(index_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Index to be int scalar"; - } - - break; - } - - case SpvOpVectorInsertDynamic: { - const SpvOp result_opcode = _.GetIdOpcode(result_type); - if (result_opcode != SpvOpTypeVector) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be OpTypeVector"; - } - - const uint32_t vector_type = _.GetOperandTypeId(inst, 2); - if (vector_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Vector type to be equal to Result Type"; - } - - const uint32_t component_type = _.GetOperandTypeId(inst, 3); - if (_.GetComponentType(result_type) != component_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Component type to be equal to Result Type " - << "component type"; - } - - const uint32_t index_type = _.GetOperandTypeId(inst, 4); - if (!_.IsIntScalarType(index_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Index to be int scalar"; - } - - break; - } - - case SpvOpVectorShuffle: { - // Handled in validate_id.cpp. - // TODO(atgoo@github.com) Consider moving it here. - break; - } - - case SpvOpCompositeConstruct: { - const SpvOp result_opcode = _.GetIdOpcode(result_type); - switch (result_opcode) { - case SpvOpTypeVector: { - const uint32_t num_result_components = _.GetDimension(result_type); - const uint32_t result_component_type = - _.GetComponentType(result_type); - uint32_t given_component_count = 0; - - if (num_operands <= 3) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected number of constituents to be at least 2"; - } - - for (uint32_t operand_index = 2; operand_index < num_operands; - ++operand_index) { - const uint32_t operand_type = - _.GetOperandTypeId(inst, operand_index); - if (operand_type == result_component_type) { - ++given_component_count; - } else { - if (_.GetIdOpcode(operand_type) != SpvOpTypeVector || - _.GetComponentType(operand_type) != result_component_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Constituents to be scalars or vectors of " - << "the same type as Result Type components"; - } - - given_component_count += _.GetDimension(operand_type); - } - } - - if (num_result_components != given_component_count) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected total number of given components to be equal " - << "to the size of Result Type vector"; - } - - break; - } - case SpvOpTypeMatrix: { - uint32_t result_num_rows = 0; - uint32_t result_num_cols = 0; - uint32_t result_col_type = 0; - uint32_t result_component_type = 0; - if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, - &result_num_cols, &result_col_type, - &result_component_type)) { - assert(0); - } - - if (result_num_cols + 2 != num_operands) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected total number of Constituents to be equal " - << "to the number of columns of Result Type matrix"; - } - - for (uint32_t operand_index = 2; operand_index < num_operands; - ++operand_index) { - const uint32_t operand_type = - _.GetOperandTypeId(inst, operand_index); - if (operand_type != result_col_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Constituent type to be equal to the column " - << "type Result Type matrix"; - } - } - - break; - } - case SpvOpTypeArray: { - const Instruction* const array_inst = _.FindDef(result_type); - assert(array_inst); - assert(array_inst->opcode() == SpvOpTypeArray); - - auto size = _.FindDef(array_inst->word(3)); - if (spvOpcodeIsSpecConstant(size->opcode())) { - // Cannot verify against the size of this array. - break; - } - - uint64_t array_size = 0; - if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) { - assert(0 && "Array type definition is corrupt"); - } - - if (array_size + 2 != num_operands) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected total number of Constituents to be equal " - << "to the number of elements of Result Type array"; - } - - const uint32_t result_component_type = array_inst->word(2); - for (uint32_t operand_index = 2; operand_index < num_operands; - ++operand_index) { - const uint32_t operand_type = - _.GetOperandTypeId(inst, operand_index); - if (operand_type != result_component_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Constituent type to be equal to the column " - << "type Result Type array"; - } - } - - break; - } - case SpvOpTypeStruct: { - const Instruction* const struct_inst = _.FindDef(result_type); - assert(struct_inst); - assert(struct_inst->opcode() == SpvOpTypeStruct); - - if (struct_inst->operands().size() + 1 != num_operands) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected total number of Constituents to be equal " - << "to the number of members of Result Type struct"; - } - - for (uint32_t operand_index = 2; operand_index < num_operands; - ++operand_index) { - const uint32_t operand_type = - _.GetOperandTypeId(inst, operand_index); - const uint32_t member_type = struct_inst->word(operand_index); - if (operand_type != member_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Constituent type to be equal to the " - << "corresponding member type of Result Type struct"; - } - } - - break; - } - default: { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be a composite type"; - } - } - - break; - } - - case SpvOpCompositeExtract: { - uint32_t member_type = 0; - if (spv_result_t error = - GetExtractInsertValueType(_, *inst, &member_type)) { - return error; - } - - if (result_type != member_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Op" << spvOpcodeString(opcode) << " result type (Op" - << spvOpcodeString(_.GetIdOpcode(result_type)) - << ") does not match the type that results from indexing into " - "the " - "composite (Op" - << spvOpcodeString(_.GetIdOpcode(member_type)) << ")."; - } - break; - } - - case SpvOpCompositeInsert: { - const uint32_t object_type = _.GetOperandTypeId(inst, 2); - const uint32_t composite_type = _.GetOperandTypeId(inst, 3); - - if (result_type != composite_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "The Result Type must be the same as Composite type in Op" - << spvOpcodeString(opcode) << " yielding Result Id " - << result_type << "."; - } - - uint32_t member_type = 0; - if (spv_result_t error = - GetExtractInsertValueType(_, *inst, &member_type)) { - return error; - } - - if (object_type != member_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "The Object type (Op" - << spvOpcodeString(_.GetIdOpcode(object_type)) << ") in Op" - << spvOpcodeString(opcode) - << " does not match the type that results from indexing into " - "the Composite (Op" - << spvOpcodeString(_.GetIdOpcode(member_type)) << ")."; - } - break; - } - - case SpvOpCopyObject: { - if (!spvOpcodeGeneratesType(_.GetIdOpcode(result_type))) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be a type"; - } - - const uint32_t operand_type = _.GetOperandTypeId(inst, 2); - if (operand_type != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type and Operand type to be the same"; - } - - break; - } - - case SpvOpTranspose: { - uint32_t result_num_rows = 0; - uint32_t result_num_cols = 0; - uint32_t result_col_type = 0; - uint32_t result_component_type = 0; - if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols, - &result_col_type, &result_component_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be a matrix type"; - } - - const uint32_t matrix_type = _.GetOperandTypeId(inst, 2); - uint32_t matrix_num_rows = 0; - uint32_t matrix_num_cols = 0; - uint32_t matrix_col_type = 0; - uint32_t matrix_component_type = 0; - if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols, - &matrix_col_type, &matrix_component_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Matrix to be of type OpTypeMatrix"; - } - - if (result_component_type != matrix_component_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected component types of Matrix and Result Type to be " - << "identical"; - } - - if (result_num_rows != matrix_num_cols || - result_num_cols != matrix_num_rows) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected number of columns and the column size of Matrix " - << "to be the reverse of those of Result Type"; - } - - break; - } - - default: - break; - } - - return SPV_SUCCESS; -} - -} // namespace libspirv diff --git a/3rdparty/spirv-tools/source/validate_decorations.cpp b/3rdparty/spirv-tools/source/validate_decorations.cpp deleted file mode 100644 index 11d3ebca9..000000000 --- a/3rdparty/spirv-tools/source/validate_decorations.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "validate.h" - -#include -#include - -#include "diagnostic.h" -#include "opcode.h" -#include "spirv_target_env.h" -#include "val/validation_state.h" - -using libspirv::Decoration; -using libspirv::DiagnosticStream; -using libspirv::Instruction; -using libspirv::ValidationState_t; - -namespace { - -// Returns whether the given variable has a BuiltIn decoration. -bool isBuiltInVar(uint32_t var_id, ValidationState_t& vstate) { - const auto& decorations = vstate.id_decorations(var_id); - return std::any_of( - decorations.begin(), decorations.end(), - [](const Decoration& d) { return SpvDecorationBuiltIn == d.dec_type(); }); -} - -// Returns whether the given structure type has any members with BuiltIn -// decoration. -bool isBuiltInStruct(uint32_t struct_id, ValidationState_t& vstate) { - const auto& decorations = vstate.id_decorations(struct_id); - return std::any_of( - decorations.begin(), decorations.end(), [](const Decoration& d) { - return SpvDecorationBuiltIn == d.dec_type() && - Decoration::kInvalidMember != d.struct_member_index(); - }); -} - -// Returns true if the given ID has the Import LinkageAttributes decoration. -bool hasImportLinkageAttribute(uint32_t id, ValidationState_t& vstate) { - const auto& decorations = vstate.id_decorations(id); - return std::any_of(decorations.begin(), decorations.end(), - [](const Decoration& d) { - return SpvDecorationLinkageAttributes == d.dec_type() && - d.params().size() >= 2u && - d.params().back() == SpvLinkageTypeImport; - }); -} - -spv_result_t CheckLinkageAttrOfFunctions(ValidationState_t& vstate) { - for (const auto& function : vstate.functions()) { - if (function.block_count() == 0u) { - // A function declaration (an OpFunction with no basic blocks), must have - // a Linkage Attributes Decoration with the Import Linkage Type. - if (!hasImportLinkageAttribute(function.id(), vstate)) { - return vstate.diag(SPV_ERROR_INVALID_BINARY) - << "Function declaration (id " << function.id() - << ") must have a LinkageAttributes decoration with the Import " - "Linkage type."; - } - } else { - if (hasImportLinkageAttribute(function.id(), vstate)) { - return vstate.diag(SPV_ERROR_INVALID_BINARY) - << "Function definition (id " << function.id() - << ") may not be decorated with Import Linkage type."; - } - } - } - return SPV_SUCCESS; -} - -// Checks whether an imported variable is initialized by this module. -spv_result_t CheckImportedVariableInitialization(ValidationState_t& vstate) { - // According the SPIR-V Spec 2.16.1, it is illegal to initialize an imported - // variable. This means that a module-scope OpVariable with initialization - // value cannot be marked with the Import Linkage Type (import type id = 1). - for (auto global_var_id : vstate.global_vars()) { - // Initializer is an optional argument for OpVariable. If initializer - // is present, the instruction will have 5 words. - auto variable_instr = vstate.FindDef(global_var_id); - if (variable_instr->words().size() == 5u && - hasImportLinkageAttribute(global_var_id, vstate)) { - return vstate.diag(SPV_ERROR_INVALID_ID) - << "A module-scope OpVariable with initialization value " - "cannot be marked with the Import Linkage Type."; - } - } - return SPV_SUCCESS; -} - -// Checks whether a builtin variable is valid. -spv_result_t CheckBuiltInVariable(uint32_t var_id, ValidationState_t& vstate) { - const auto& decorations = vstate.id_decorations(var_id); - for (const auto& d : decorations) { - if (spvIsVulkanEnv(vstate.context()->target_env)) { - if (d.dec_type() == SpvDecorationLocation || - d.dec_type() == SpvDecorationComponent) { - return vstate.diag(SPV_ERROR_INVALID_ID) - << "A BuiltIn variable (id " << var_id - << ") cannot have any Location or Component decorations"; - } - } - } - return SPV_SUCCESS; -} - -// Checks whether proper decorations have been appied to the entry points. -spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) { - for (uint32_t entry_point : vstate.entry_points()) { - const auto& interfaces = vstate.entry_point_interfaces(entry_point); - int num_builtin_inputs = 0; - int num_builtin_outputs = 0; - for (auto interface : interfaces) { - Instruction* var_instr = vstate.FindDef(interface); - if (SpvOpVariable != var_instr->opcode()) { - return vstate.diag(SPV_ERROR_INVALID_ID) - << "Interfaces passed to OpEntryPoint must be of type " - "OpTypeVariable. Found Op" - << spvOpcodeString(static_cast(var_instr->opcode())) - << "."; - } - const uint32_t ptr_id = var_instr->word(1); - Instruction* ptr_instr = vstate.FindDef(ptr_id); - // It is guaranteed (by validator ID checks) that ptr_instr is - // OpTypePointer. Word 3 of this instruction is the type being pointed to. - const uint32_t type_id = ptr_instr->word(3); - Instruction* type_instr = vstate.FindDef(type_id); - const auto storage_class = - static_cast(var_instr->word(3)); - if (storage_class != SpvStorageClassInput && - storage_class != SpvStorageClassOutput) { - return vstate.diag(SPV_ERROR_INVALID_ID) - << "OpEntryPoint interfaces must be OpVariables with " - "Storage Class of Input(1) or Output(3). Found Storage Class " - << storage_class << " for Entry Point id " << entry_point << "."; - } - if (type_instr && SpvOpTypeStruct == type_instr->opcode() && - isBuiltInStruct(type_id, vstate)) { - if (storage_class == SpvStorageClassInput) ++num_builtin_inputs; - if (storage_class == SpvStorageClassOutput) ++num_builtin_outputs; - if (num_builtin_inputs > 1 || num_builtin_outputs > 1) break; - if (auto error = CheckBuiltInVariable(interface, vstate)) return error; - } else if (isBuiltInVar(interface, vstate)) { - if (auto error = CheckBuiltInVariable(interface, vstate)) return error; - } - } - if (num_builtin_inputs > 1 || num_builtin_outputs > 1) { - return vstate.diag(SPV_ERROR_INVALID_BINARY) - << "There must be at most one object per Storage Class that can " - "contain a structure type containing members decorated with " - "BuiltIn, consumed per entry-point. Entry Point id " - << entry_point << " does not meet this requirement."; - } - // The LinkageAttributes Decoration cannot be applied to functions targeted - // by an OpEntryPoint instruction - for (auto& decoration : vstate.id_decorations(entry_point)) { - if (SpvDecorationLinkageAttributes == decoration.dec_type()) { - const char* linkage_name = - reinterpret_cast(&decoration.params()[0]); - return vstate.diag(SPV_ERROR_INVALID_BINARY) - << "The LinkageAttributes Decoration (Linkage name: " - << linkage_name << ") cannot be applied to function id " - << entry_point - << " because it is targeted by an OpEntryPoint instruction."; - } - } - } - return SPV_SUCCESS; -} - -} // anonymous namespace - -namespace libspirv { - -// Validates that decorations have been applied properly. -spv_result_t ValidateDecorations(ValidationState_t& vstate) { - if (auto error = CheckImportedVariableInitialization(vstate)) return error; - if (auto error = CheckDecorationsOfEntryPoints(vstate)) return error; - if (auto error = CheckLinkageAttrOfFunctions(vstate)) return error; - return SPV_SUCCESS; -} - -} // namespace libspirv diff --git a/3rdparty/spirv-tools/source/validate_id.cpp b/3rdparty/spirv-tools/source/validate_id.cpp deleted file mode 100644 index b305f496f..000000000 --- a/3rdparty/spirv-tools/source/validate_id.cpp +++ /dev/null @@ -1,2767 +0,0 @@ -// Copyright (c) 2015-2016 The Khronos Group Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "validate.h" - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "diagnostic.h" -#include "instruction.h" -#include "message.h" -#include "opcode.h" -#include "operand.h" -#include "spirv-tools/libspirv.h" -#include "spirv_validator_options.h" -#include "val/function.h" -#include "val/validation_state.h" - -using libspirv::Decoration; -using libspirv::Function; -using libspirv::ValidationState_t; -using std::function; -using std::ignore; -using std::make_pair; -using std::pair; -using std::unordered_set; -using std::vector; - -namespace { - -class idUsage { - public: - idUsage(spv_const_context context, const spv_instruction_t* pInsts, - const uint64_t instCountArg, const SpvMemoryModel memoryModelArg, - const SpvAddressingModel addressingModelArg, - const ValidationState_t& module, const vector& entry_points, - spv_position positionArg, const spvtools::MessageConsumer& consumer) - : targetEnv(context->target_env), - opcodeTable(context->opcode_table), - operandTable(context->operand_table), - extInstTable(context->ext_inst_table), - firstInst(pInsts), - instCount(instCountArg), - memoryModel(memoryModelArg), - addressingModel(addressingModelArg), - position(positionArg), - consumer_(consumer), - module_(module), - entry_points_(entry_points) {} - - bool isValid(const spv_instruction_t* inst); - - template - bool isValid(const spv_instruction_t* inst, const spv_opcode_desc); - - private: - const spv_target_env targetEnv; - const spv_opcode_table opcodeTable; - const spv_operand_table operandTable; - const spv_ext_inst_table extInstTable; - const spv_instruction_t* const firstInst; - const uint64_t instCount; - const SpvMemoryModel memoryModel; - const SpvAddressingModel addressingModel; - spv_position position; - const spvtools::MessageConsumer& consumer_; - const ValidationState_t& module_; - vector entry_points_; - - // Returns true if the two instructions represent structs that, as far as the - // validator can tell, have the exact same data layout. - bool AreLayoutCompatibleStructs(const libspirv::Instruction* type1, - const libspirv::Instruction* type2); - - // Returns true if the operands to the OpTypeStruct instruction defining the - // types are the same or are layout compatible types. |type1| and |type2| must - // be OpTypeStruct instructions. - bool HaveLayoutCompatibleMembers(const libspirv::Instruction* type1, - const libspirv::Instruction* type2); - - // Returns true if all decorations that affect the data layout of the struct - // (like Offset), are the same for the two types. |type1| and |type2| must be - // OpTypeStruct instructions. - bool HaveSameLayoutDecorations(const libspirv::Instruction* type1, - const libspirv::Instruction* type2); - bool HasConflictingMemberOffsets( - const vector& type1_decorations, - const vector& type2_decorations) const; -}; - -#define DIAG(INDEX) \ - position->index += INDEX; \ - libspirv::DiagnosticStream helper(*position, consumer_, \ - SPV_ERROR_INVALID_DIAGNOSTIC); \ - helper - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc) { - assert(0 && "Unimplemented!"); - return false; -} -#endif // 0 - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto typeIndex = 1; - auto type = module_.FindDef(inst->words[typeIndex]); - if (!type || SpvOpTypeStruct != type->opcode()) { - DIAG(typeIndex) << "OpMemberName Type '" << inst->words[typeIndex] - << "' is not a struct type."; - return false; - } - auto memberIndex = 2; - auto member = inst->words[memberIndex]; - auto memberCount = (uint32_t)(type->words().size() - 2); - if (memberCount <= member) { - DIAG(memberIndex) << "OpMemberName Member '" - << inst->words[memberIndex] - << "' index is larger than Type '" << type->id() - << "'s member count."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto fileIndex = 1; - auto file = module_.FindDef(inst->words[fileIndex]); - if (!file || SpvOpString != file->opcode()) { - DIAG(fileIndex) << "OpLine Target '" << inst->words[fileIndex] - << "' is not an OpString."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto decorationIndex = 2; - auto decoration = inst->words[decorationIndex]; - if (decoration == SpvDecorationSpecId) { - auto targetIndex = 1; - auto target = module_.FindDef(inst->words[targetIndex]); - if (!target || !spvOpcodeIsScalarSpecConstant(target->opcode())) { - DIAG(targetIndex) << "OpDecorate SpectId decoration target '" - << inst->words[decorationIndex] - << "' is not a scalar specialization constant."; - return false; - } - } - // TODO: Add validations for all decorations. - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto structTypeIndex = 1; - auto structType = module_.FindDef(inst->words[structTypeIndex]); - if (!structType || SpvOpTypeStruct != structType->opcode()) { - DIAG(structTypeIndex) << "OpMemberDecorate Structure type '" - << inst->words[structTypeIndex] - << "' is not a struct type."; - return false; - } - auto memberIndex = 2; - auto member = inst->words[memberIndex]; - auto memberCount = static_cast(structType->words().size() - 2); - if (memberCount < member) { - DIAG(memberIndex) << "Index " << member - << " provided in OpMemberDecorate for struct " - << inst->words[structTypeIndex] - << " is out of bounds. The structure has " << memberCount - << " members. Largest valid index is " << memberCount - 1 - << "."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto decorationGroupIndex = 1; - auto decorationGroup = module_.FindDef(inst->words[decorationGroupIndex]); - - for (auto pair : decorationGroup->uses()) { - auto use = pair.first; - if (use->opcode() != SpvOpDecorate && use->opcode() != SpvOpGroupDecorate && - use->opcode() != SpvOpGroupMemberDecorate && - use->opcode() != SpvOpName) { - DIAG(decorationGroupIndex) << "Result id of OpDecorationGroup can only " - << "be targeted by OpName, OpGroupDecorate, " - << "OpDecorate, and OpGroupMemberDecorate"; - return false; - } - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto decorationGroupIndex = 1; - auto decorationGroup = module_.FindDef(inst->words[decorationGroupIndex]); - if (!decorationGroup || SpvOpDecorationGroup != decorationGroup->opcode()) { - DIAG(decorationGroupIndex) - << "OpGroupDecorate Decoration group '" - << inst->words[decorationGroupIndex] << "' is not a decoration group."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto decorationGroupIndex = 1; - auto decorationGroup = module_.FindDef(inst->words[decorationGroupIndex]); - if (!decorationGroup || SpvOpDecorationGroup != decorationGroup->opcode()) { - DIAG(decorationGroupIndex) - << "OpGroupMemberDecorate Decoration group '" - << inst->words[decorationGroupIndex] << "' is not a decoration group."; - return false; - } - // Grammar checks ensures that the number of arguments to this instruction - // is an odd number: 1 decoration group + (id,literal) pairs. - for (size_t i = 2; i + 1 < inst->words.size(); i = i + 2) { - const uint32_t struct_id = inst->words[i]; - const uint32_t index = inst->words[i + 1]; - auto struct_instr = module_.FindDef(struct_id); - if (!struct_instr || SpvOpTypeStruct != struct_instr->opcode()) { - DIAG(i) << "OpGroupMemberDecorate Structure type '" << struct_id - << "' is not a struct type."; - return false; - } - const uint32_t num_struct_members = - static_cast(struct_instr->words().size() - 2); - if (index >= num_struct_members) { - DIAG(i) << "Index " << index - << " provided in OpGroupMemberDecorate for struct " - << struct_id << " is out of bounds. The structure has " - << num_struct_members << " members. Largest valid index is " - << num_struct_members - 1 << "."; - return false; - } - } - return true; -} - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif // 0 - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto entryPointIndex = 2; - auto entryPoint = module_.FindDef(inst->words[entryPointIndex]); - if (!entryPoint || SpvOpFunction != entryPoint->opcode()) { - DIAG(entryPointIndex) << "OpEntryPoint Entry Point '" - << inst->words[entryPointIndex] - << "' is not a function."; - return false; - } - // don't check kernel function signatures - const SpvExecutionModel executionModel = SpvExecutionModel(inst->words[1]); - if (executionModel != SpvExecutionModelKernel) { - // TODO: Check the entry point signature is void main(void), may be subject - // to change - auto entryPointType = module_.FindDef(entryPoint->words()[4]); - if (!entryPointType || 3 != entryPointType->words().size()) { - DIAG(entryPointIndex) - << "OpEntryPoint Entry Point '" << inst->words[entryPointIndex] - << "'s function parameter count is not zero."; - return false; - } - } - - std::stack call_stack; - std::set visited; - call_stack.push(entryPoint->id()); - while (!call_stack.empty()) { - const uint32_t called_func_id = call_stack.top(); - call_stack.pop(); - if (!visited.insert(called_func_id).second) continue; - - const Function* called_func = module_.function(called_func_id); - assert(called_func); - - std::string reason; - if (!called_func->IsCompatibleWithExecutionModel(executionModel, &reason)) { - DIAG(entryPointIndex) - << "OpEntryPoint Entry Point '" << inst->words[entryPointIndex] - << "'s callgraph contains function " << called_func_id - << ", which cannot be used with the current execution model:\n" - << reason; - return false; - } - - for (uint32_t new_call : called_func->function_call_targets()) { - call_stack.push(new_call); - } - } - - auto returnType = module_.FindDef(entryPoint->type_id()); - if (!returnType || SpvOpTypeVoid != returnType->opcode()) { - DIAG(entryPointIndex) << "OpEntryPoint Entry Point '" - << inst->words[entryPointIndex] - << "'s function return type is not void."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto entryPointIndex = 1; - auto entryPointID = inst->words[entryPointIndex]; - auto found = - std::find(entry_points_.cbegin(), entry_points_.cend(), entryPointID); - if (found == entry_points_.cend()) { - DIAG(entryPointIndex) << "OpExecutionMode Entry Point '" - << inst->words[entryPointIndex] - << "' is not the Entry Point " - "operand of an OpEntryPoint."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto componentIndex = 2; - auto componentType = module_.FindDef(inst->words[componentIndex]); - if (!componentType || !spvOpcodeIsScalarType(componentType->opcode())) { - DIAG(componentIndex) << "OpTypeVector Component Type '" - << inst->words[componentIndex] - << "' is not a scalar type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto columnTypeIndex = 2; - auto columnType = module_.FindDef(inst->words[columnTypeIndex]); - if (!columnType || SpvOpTypeVector != columnType->opcode()) { - DIAG(columnTypeIndex) << "OpTypeMatrix Column Type '" - << inst->words[columnTypeIndex] - << "' is not a vector."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t*, - const spv_opcode_desc) { - // OpTypeSampler takes no arguments in Rev31 and beyond. - return true; -} - -// True if the integer constant is > 0. constWords are words of the -// constant-defining instruction (either OpConstant or -// OpSpecConstant). typeWords are the words of the constant's-type-defining -// OpTypeInt. -bool aboveZero(const vector& constWords, - const vector& typeWords) { - const uint32_t width = typeWords[2]; - const bool is_signed = typeWords[3] > 0; - const uint32_t loWord = constWords[3]; - if (width > 32) { - // The spec currently doesn't allow integers wider than 64 bits. - const uint32_t hiWord = constWords[4]; // Must exist, per spec. - if (is_signed && (hiWord >> 31)) return false; - return (loWord | hiWord) > 0; - } else { - if (is_signed && (loWord >> 31)) return false; - return loWord > 0; - } -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto elementTypeIndex = 2; - auto elementType = module_.FindDef(inst->words[elementTypeIndex]); - if (!elementType || !spvOpcodeGeneratesType(elementType->opcode())) { - DIAG(elementTypeIndex) << "OpTypeArray Element Type '" - << inst->words[elementTypeIndex] - << "' is not a type."; - return false; - } - auto lengthIndex = 3; - auto length = module_.FindDef(inst->words[lengthIndex]); - if (!length || !spvOpcodeIsConstant(length->opcode())) { - DIAG(lengthIndex) << "OpTypeArray Length '" << inst->words[lengthIndex] - << "' is not a scalar constant type."; - return false; - } - - // NOTE: Check the initialiser value of the constant - auto constInst = length->words(); - auto constResultTypeIndex = 1; - auto constResultType = module_.FindDef(constInst[constResultTypeIndex]); - if (!constResultType || SpvOpTypeInt != constResultType->opcode()) { - DIAG(lengthIndex) << "OpTypeArray Length '" << inst->words[lengthIndex] - << "' is not a constant integer type."; - return false; - } - - switch (length->opcode()) { - case SpvOpSpecConstant: - case SpvOpConstant: - if (aboveZero(length->words(), constResultType->words())) break; - // Else fall through! - case SpvOpConstantNull: { - DIAG(lengthIndex) << "OpTypeArray Length '" - << inst->words[lengthIndex] - << "' default value must be at least 1."; - return false; - } - case SpvOpSpecConstantOp: - // Assume it's OK, rather than try to evaluate the operation. - break; - default: - assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int"); - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto elementTypeIndex = 2; - auto elementType = module_.FindDef(inst->words[elementTypeIndex]); - if (!elementType || !spvOpcodeGeneratesType(elementType->opcode())) { - DIAG(elementTypeIndex) << "OpTypeRuntimeArray Element Type '" - << inst->words[elementTypeIndex] - << "' is not a type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - ValidationState_t& vstate = const_cast(module_); - const uint32_t struct_id = inst->words[1]; - for (size_t memberTypeIndex = 2; memberTypeIndex < inst->words.size(); - ++memberTypeIndex) { - auto memberTypeId = inst->words[memberTypeIndex]; - auto memberType = module_.FindDef(memberTypeId); - if (!memberType || !spvOpcodeGeneratesType(memberType->opcode())) { - DIAG(memberTypeIndex) - << "OpTypeStruct Member Type '" << inst->words[memberTypeIndex] - << "' is not a type."; - return false; - } - if (SpvOpTypeStruct == memberType->opcode() && - module_.IsStructTypeWithBuiltInMember(memberTypeId)) { - DIAG(memberTypeIndex) - << "Structure " << memberTypeId - << " contains members with BuiltIn decoration. Therefore this " - "structure may not be contained as a member of another structure " - "type. Structure " - << struct_id << " contains structure " << memberTypeId << "."; - return false; - } - if (module_.IsForwardPointer(memberTypeId)) { - if (memberType->opcode() != SpvOpTypePointer) { - DIAG(memberTypeIndex) << "Found a forward reference to a non-pointer " - "type in OpTypeStruct instruction."; - return false; - } - // If we're dealing with a forward pointer: - // Find out the type that the pointer is pointing to (must be struct) - // word 3 is the of the type being pointed to. - auto typePointingTo = module_.FindDef(memberType->words()[3]); - if (typePointingTo && typePointingTo->opcode() != SpvOpTypeStruct) { - // Forward declared operands of a struct may only point to a struct. - DIAG(memberTypeIndex) - << "A forward reference operand in an OpTypeStruct must be an " - "OpTypePointer that points to an OpTypeStruct. " - "Found OpTypePointer that points to Op" - << spvOpcodeString(static_cast(typePointingTo->opcode())) - << "."; - return false; - } - } - } - std::unordered_set built_in_members; - for (auto decoration : vstate.id_decorations(struct_id)) { - if (decoration.dec_type() == SpvDecorationBuiltIn && - decoration.struct_member_index() != Decoration::kInvalidMember) { - built_in_members.insert(decoration.struct_member_index()); - } - } - int num_struct_members = static_cast(inst->words.size() - 2); - int num_builtin_members = static_cast(built_in_members.size()); - if (num_builtin_members > 0 && num_builtin_members != num_struct_members) { - DIAG(0) - << "When BuiltIn decoration is applied to a structure-type member, " - "all members of that structure type must also be decorated with " - "BuiltIn (No allowed mixing of built-in variables and " - "non-built-in variables within a single structure). Structure id " - << struct_id << " does not meet this requirement."; - return false; - } - if (num_builtin_members > 0) { - vstate.RegisterStructTypeWithBuiltInMember(struct_id); - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto typeIndex = 3; - auto type = module_.FindDef(inst->words[typeIndex]); - if (!type || !spvOpcodeGeneratesType(type->opcode())) { - DIAG(typeIndex) << "OpTypePointer Type '" << inst->words[typeIndex] - << "' is not a type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto returnTypeIndex = 2; - auto returnType = module_.FindDef(inst->words[returnTypeIndex]); - if (!returnType || !spvOpcodeGeneratesType(returnType->opcode())) { - DIAG(returnTypeIndex) << "OpTypeFunction Return Type '" - << inst->words[returnTypeIndex] << "' is not a type."; - return false; - } - size_t num_args = 0; - for (size_t paramTypeIndex = 3; paramTypeIndex < inst->words.size(); - ++paramTypeIndex, ++num_args) { - auto paramType = module_.FindDef(inst->words[paramTypeIndex]); - if (!paramType || !spvOpcodeGeneratesType(paramType->opcode())) { - DIAG(paramTypeIndex) << "OpTypeFunction Parameter Type '" - << inst->words[paramTypeIndex] << "' is not a type."; - return false; - } - } - const uint32_t num_function_args_limit = - module_.options()->universal_limits_.max_function_args; - if (num_args > num_function_args_limit) { - DIAG(returnTypeIndex) << "OpTypeFunction may not take more than " - << num_function_args_limit - << " arguments. OpTypeFunction '" - << inst->words[1] << "' has " << num_args - << " arguments."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t*, - const spv_opcode_desc) { - // OpTypePipe has no ID arguments. - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || SpvOpTypeBool != resultType->opcode()) { - DIAG(resultTypeIndex) << "OpConstantTrue Result Type '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || SpvOpTypeBool != resultType->opcode()) { - DIAG(resultTypeIndex) << "OpConstantFalse Result Type '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || !spvOpcodeIsComposite(resultType->opcode())) { - DIAG(resultTypeIndex) << "OpConstantComposite Result Type '" - << inst->words[resultTypeIndex] - << "' is not a composite type."; - return false; - } - - auto constituentCount = inst->words.size() - 3; - switch (resultType->opcode()) { - case SpvOpTypeVector: { - auto componentCount = resultType->words()[3]; - if (componentCount != constituentCount) { - // TODO: Output ID's on diagnostic - DIAG(inst->words.size() - 1) - << "OpConstantComposite Constituent count does not match " - "Result Type '" - << resultType->id() << "'s vector component count."; - return false; - } - auto componentType = module_.FindDef(resultType->words()[2]); - assert(componentType); - for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); - constituentIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || - !spvOpcodeIsConstantOrUndef(constituent->opcode())) { - DIAG(constituentIndex) << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant or undef."; - return false; - } - auto constituentResultType = module_.FindDef(constituent->type_id()); - if (!constituentResultType || - componentType->opcode() != constituentResultType->opcode()) { - DIAG(constituentIndex) - << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "'s type does not match Result Type '" << resultType->id() - << "'s vector element type."; - return false; - } - } - } break; - case SpvOpTypeMatrix: { - auto columnCount = resultType->words()[3]; - if (columnCount != constituentCount) { - // TODO: Output ID's on diagnostic - DIAG(inst->words.size() - 1) - << "OpConstantComposite Constituent count does not match " - "Result Type '" - << resultType->id() << "'s matrix column count."; - return false; - } - - auto columnType = module_.FindDef(resultType->words()[2]); - assert(columnType); - auto componentCount = columnType->words()[3]; - auto componentType = module_.FindDef(columnType->words()[2]); - assert(componentType); - - for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); - constituentIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || !(SpvOpConstantComposite == constituent->opcode() || - SpvOpUndef == constituent->opcode())) { - // The message says "... or undef" because the spec does not say - // undef is a constant. - DIAG(constituentIndex) << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant composite or undef."; - return false; - } - auto vector = module_.FindDef(constituent->type_id()); - assert(vector); - if (columnType->opcode() != vector->opcode()) { - DIAG(constituentIndex) - << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' type does not match Result Type '" << resultType->id() - << "'s matrix column type."; - return false; - } - auto vectorComponentType = module_.FindDef(vector->words()[2]); - assert(vectorComponentType); - if (componentType->id() != vectorComponentType->id()) { - DIAG(constituentIndex) - << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' component type does not match Result Type '" - << resultType->id() << "'s matrix column component type."; - return false; - } - if (componentCount != vector->words()[3]) { - DIAG(constituentIndex) - << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' vector component count does not match Result Type '" - << resultType->id() << "'s vector component count."; - return false; - } - } - } break; - case SpvOpTypeArray: { - auto elementType = module_.FindDef(resultType->words()[2]); - assert(elementType); - auto length = module_.FindDef(resultType->words()[3]); - assert(length); - if (length->words()[3] != constituentCount) { - DIAG(inst->words.size() - 1) - << "OpConstantComposite Constituent count does not match " - "Result Type '" - << resultType->id() << "'s array length."; - return false; - } - for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); - constituentIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || - !spvOpcodeIsConstantOrUndef(constituent->opcode())) { - DIAG(constituentIndex) << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant or undef."; - return false; - } - auto constituentType = module_.FindDef(constituent->type_id()); - assert(constituentType); - if (elementType->id() != constituentType->id()) { - DIAG(constituentIndex) - << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "'s type does not match Result Type '" << resultType->id() - << "'s array element type."; - return false; - } - } - } break; - case SpvOpTypeStruct: { - auto memberCount = resultType->words().size() - 2; - if (memberCount != constituentCount) { - DIAG(resultTypeIndex) << "OpConstantComposite Constituent '" - << inst->words[resultTypeIndex] - << "' count does not match Result Type '" - << resultType->id() << "'s struct member count."; - return false; - } - for (uint32_t constituentIndex = 3, memberIndex = 2; - constituentIndex < inst->words.size(); - constituentIndex++, memberIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || - !spvOpcodeIsConstantOrUndef(constituent->opcode())) { - DIAG(constituentIndex) << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant or undef."; - return false; - } - auto constituentType = module_.FindDef(constituent->type_id()); - assert(constituentType); - - auto memberType = module_.FindDef(resultType->words()[memberIndex]); - assert(memberType); - if (memberType->id() != constituentType->id()) { - DIAG(constituentIndex) - << "OpConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' type does not match the Result Type '" - << resultType->id() << "'s member type."; - return false; - } - } - } break; - default: { assert(0 && "Unreachable!"); } break; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || SpvOpTypeSampler != resultType->opcode()) { - DIAG(resultTypeIndex) << "OpConstantSampler Result Type '" - << inst->words[resultTypeIndex] - << "' is not a sampler type."; - return false; - } - return true; -} - -// True if instruction defines a type that can have a null value, as defined by -// the SPIR-V spec. Tracks composite-type components through module to check -// nullability transitively. -bool IsTypeNullable(const vector& instruction, - const ValidationState_t& module) { - uint16_t opcode; - uint16_t word_count; - spvOpcodeSplit(instruction[0], &word_count, &opcode); - switch (static_cast(opcode)) { - case SpvOpTypeBool: - case SpvOpTypeInt: - case SpvOpTypeFloat: - case SpvOpTypePointer: - case SpvOpTypeEvent: - case SpvOpTypeDeviceEvent: - case SpvOpTypeReserveId: - case SpvOpTypeQueue: - return true; - case SpvOpTypeArray: - case SpvOpTypeMatrix: - case SpvOpTypeVector: { - auto base_type = module.FindDef(instruction[2]); - return base_type && IsTypeNullable(base_type->words(), module); - } - case SpvOpTypeStruct: { - for (size_t elementIndex = 2; elementIndex < instruction.size(); - ++elementIndex) { - auto element = module.FindDef(instruction[elementIndex]); - if (!element || !IsTypeNullable(element->words(), module)) return false; - } - return true; - } - default: - return false; - } -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || !IsTypeNullable(resultType->words(), module_)) { - DIAG(resultTypeIndex) << "OpConstantNull Result Type '" - << inst->words[resultTypeIndex] - << "' cannot have a null value."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || SpvOpTypeBool != resultType->opcode()) { - DIAG(resultTypeIndex) << "OpSpecConstantTrue Result Type '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || SpvOpTypeBool != resultType->opcode()) { - DIAG(resultTypeIndex) << "OpSpecConstantFalse Result Type '" - << inst->words[resultTypeIndex] - << "' is not a boolean type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 2; - auto resultID = inst->words[resultTypeIndex]; - auto sampledImageInstr = module_.FindDef(resultID); - // We need to validate 2 things: - // * All OpSampledImage instructions must be in the same block in which their - // Result are consumed. - // * Result from OpSampledImage instructions must not appear as operands - // to OpPhi instructions or OpSelect instructions, or any instructions other - // than the image lookup and query instructions specified to take an operand - // whose type is OpTypeSampledImage. - std::vector consumers = module_.getSampledImageConsumers(resultID); - if (!consumers.empty()) { - for (auto consumer_id : consumers) { - auto consumer_instr = module_.FindDef(consumer_id); - auto consumer_opcode = consumer_instr->opcode(); - if (consumer_instr->block() != sampledImageInstr->block()) { - DIAG(resultTypeIndex) - << "All OpSampledImage instructions must be in the same block in " - "which their Result are consumed. OpSampledImage Result " - "Type '" - << resultID - << "' has a consumer in a different basic " - "block. The consumer instruction is '" - << consumer_id << "'."; - return false; - } - // TODO: The following check is incomplete. We should also check that the - // Sampled Image is not used by instructions that should not take - // SampledImage as an argument. We could find the list of valid - // instructions by scanning for "Sampled Image" in the operand description - // field in the grammar file. - if (consumer_opcode == SpvOpPhi || consumer_opcode == SpvOpSelect) { - DIAG(resultTypeIndex) - << "Result from OpSampledImage instruction must not appear as " - "operands of Op" - << spvOpcodeString(static_cast(consumer_opcode)) << "." - << " Found result '" << resultID << "' as an operand of '" - << consumer_id << "'."; - return false; - } - } - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - // The result type must be a composite type. - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || !spvOpcodeIsComposite(resultType->opcode())) { - DIAG(resultTypeIndex) << "OpSpecConstantComposite Result Type '" - << inst->words[resultTypeIndex] - << "' is not a composite type."; - return false; - } - // Validation checks differ based on the type of composite type. - auto constituentCount = inst->words.size() - 3; - switch (resultType->opcode()) { - // For Vectors, the following must be met: - // * Number of constituents in the result type and the vector must match. - // * All the components of the vector must have the same type (or specialize - // to the same type). OpConstant and OpSpecConstant are allowed. - // To check that condition, we check each supplied value argument's type - // against the element type of the result type. - case SpvOpTypeVector: { - auto componentCount = resultType->words()[3]; - if (componentCount != constituentCount) { - DIAG(inst->words.size() - 1) - << "OpSpecConstantComposite Constituent count does not match " - "Result Type '" - << resultType->id() << "'s vector component count."; - return false; - } - auto componentType = module_.FindDef(resultType->words()[2]); - assert(componentType); - for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); - constituentIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || - !spvOpcodeIsConstantOrUndef(constituent->opcode())) { - DIAG(constituentIndex) << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant or undef."; - return false; - } - auto constituentResultType = module_.FindDef(constituent->type_id()); - if (!constituentResultType || - componentType->opcode() != constituentResultType->opcode()) { - DIAG(constituentIndex) - << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "'s type does not match Result Type '" << resultType->id() - << "'s vector element type."; - return false; - } - } - break; - } - case SpvOpTypeMatrix: { - auto columnCount = resultType->words()[3]; - if (columnCount != constituentCount) { - DIAG(inst->words.size() - 1) - << "OpSpecConstantComposite Constituent count does not match " - "Result Type '" - << resultType->id() << "'s matrix column count."; - return false; - } - - auto columnType = module_.FindDef(resultType->words()[2]); - assert(columnType); - auto componentCount = columnType->words()[3]; - auto componentType = module_.FindDef(columnType->words()[2]); - assert(componentType); - - for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); - constituentIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - auto constituentOpCode = constituent->opcode(); - if (!constituent || !(SpvOpSpecConstantComposite == constituentOpCode || - SpvOpConstantComposite == constituentOpCode || - SpvOpUndef == constituentOpCode)) { - // The message says "... or undef" because the spec does not say - // undef is a constant. - DIAG(constituentIndex) << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant composite or undef."; - return false; - } - auto vector = module_.FindDef(constituent->type_id()); - assert(vector); - if (columnType->opcode() != vector->opcode()) { - DIAG(constituentIndex) - << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' type does not match Result Type '" << resultType->id() - << "'s matrix column type."; - return false; - } - auto vectorComponentType = module_.FindDef(vector->words()[2]); - assert(vectorComponentType); - if (componentType->id() != vectorComponentType->id()) { - DIAG(constituentIndex) - << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' component type does not match Result Type '" - << resultType->id() << "'s matrix column component type."; - return false; - } - if (componentCount != vector->words()[3]) { - DIAG(constituentIndex) - << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' vector component count does not match Result Type '" - << resultType->id() << "'s vector component count."; - return false; - } - } - break; - } - case SpvOpTypeArray: { - auto elementType = module_.FindDef(resultType->words()[2]); - assert(elementType); - auto length = module_.FindDef(resultType->words()[3]); - assert(length); - if (length->words()[3] != constituentCount) { - DIAG(inst->words.size() - 1) - << "OpSpecConstantComposite Constituent count does not match " - "Result Type '" - << resultType->id() << "'s array length."; - return false; - } - for (size_t constituentIndex = 3; constituentIndex < inst->words.size(); - constituentIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || - !spvOpcodeIsConstantOrUndef(constituent->opcode())) { - DIAG(constituentIndex) << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant or undef."; - return false; - } - auto constituentType = module_.FindDef(constituent->type_id()); - assert(constituentType); - if (elementType->id() != constituentType->id()) { - DIAG(constituentIndex) - << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "'s type does not match Result Type '" << resultType->id() - << "'s array element type."; - return false; - } - } - break; - } - case SpvOpTypeStruct: { - auto memberCount = resultType->words().size() - 2; - if (memberCount != constituentCount) { - DIAG(resultTypeIndex) << "OpSpecConstantComposite Constituent '" - << inst->words[resultTypeIndex] - << "' count does not match Result Type '" - << resultType->id() << "'s struct member count."; - return false; - } - for (uint32_t constituentIndex = 3, memberIndex = 2; - constituentIndex < inst->words.size(); - constituentIndex++, memberIndex++) { - auto constituent = module_.FindDef(inst->words[constituentIndex]); - if (!constituent || - !spvOpcodeIsConstantOrUndef(constituent->opcode())) { - DIAG(constituentIndex) << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' is not a constant or undef."; - return false; - } - auto constituentType = module_.FindDef(constituent->type_id()); - assert(constituentType); - - auto memberType = module_.FindDef(resultType->words()[memberIndex]); - assert(memberType); - if (memberType->id() != constituentType->id()) { - DIAG(constituentIndex) - << "OpSpecConstantComposite Constituent '" - << inst->words[constituentIndex] - << "' type does not match the Result Type '" - << resultType->id() << "'s member type."; - return false; - } - } - break; - } - default: { assert(0 && "Unreachable!"); } break; - } - return true; -} - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst) {} -#endif - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || SpvOpTypePointer != resultType->opcode()) { - DIAG(resultTypeIndex) << "OpVariable Result Type '" - << inst->words[resultTypeIndex] - << "' is not a pointer type."; - return false; - } - const auto initialiserIndex = 4; - if (initialiserIndex < inst->words.size()) { - const auto initialiser = module_.FindDef(inst->words[initialiserIndex]); - const auto storageClassIndex = 3; - const auto is_module_scope_var = - initialiser && (initialiser->opcode() == SpvOpVariable) && - (initialiser->word(storageClassIndex) != SpvStorageClassFunction); - const auto is_constant = - initialiser && spvOpcodeIsConstant(initialiser->opcode()); - if (!initialiser || !(is_constant || is_module_scope_var)) { - DIAG(initialiserIndex) - << "OpVariable Initializer '" << inst->words[initialiserIndex] - << "' is not a constant or module-scope variable."; - return false; - } - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType) { - DIAG(resultTypeIndex) << "OpLoad Result Type '" - << inst->words[resultTypeIndex] << "' is not defind."; - return false; - } - const bool uses_variable_pointer = - module_.features().variable_pointers || - module_.features().variable_pointers_storage_buffer; - auto pointerIndex = 3; - auto pointer = module_.FindDef(inst->words[pointerIndex]); - if (!pointer || - (addressingModel == SpvAddressingModelLogical && - ((!uses_variable_pointer && - !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || - (uses_variable_pointer && - !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { - DIAG(pointerIndex) << "OpLoad Pointer '" << inst->words[pointerIndex] - << "' is not a logical pointer."; - return false; - } - auto pointerType = module_.FindDef(pointer->type_id()); - if (!pointerType || pointerType->opcode() != SpvOpTypePointer) { - DIAG(pointerIndex) << "OpLoad type for pointer '" - << inst->words[pointerIndex] - << "' is not a pointer type."; - return false; - } - auto pointeeType = module_.FindDef(pointerType->words()[3]); - if (!pointeeType || resultType->id() != pointeeType->id()) { - DIAG(resultTypeIndex) << "OpLoad Result Type '" - << inst->words[resultTypeIndex] - << "' does not match Pointer '" << pointer->id() - << "'s type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - const bool uses_variable_pointer = - module_.features().variable_pointers || - module_.features().variable_pointers_storage_buffer; - const auto pointerIndex = 1; - auto pointer = module_.FindDef(inst->words[pointerIndex]); - if (!pointer || - (addressingModel == SpvAddressingModelLogical && - ((!uses_variable_pointer && - !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || - (uses_variable_pointer && - !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { - DIAG(pointerIndex) << "OpStore Pointer '" << inst->words[pointerIndex] - << "' is not a logical pointer."; - return false; - } - auto pointerType = module_.FindDef(pointer->type_id()); - if (!pointer || pointerType->opcode() != SpvOpTypePointer) { - DIAG(pointerIndex) << "OpStore type for pointer '" - << inst->words[pointerIndex] - << "' is not a pointer type."; - return false; - } - auto type = module_.FindDef(pointerType->words()[3]); - assert(type); - if (SpvOpTypeVoid == type->opcode()) { - DIAG(pointerIndex) << "OpStore Pointer '" << inst->words[pointerIndex] - << "'s type is void."; - return false; - } - - // validate storage class - { - uint32_t dataType; - uint32_t storageClass; - if (!module_.GetPointerTypeInfo(pointerType->id(), &dataType, - &storageClass)) { - DIAG(pointerIndex) << "OpStore Pointer '" - << inst->words[pointerIndex] - << "' is not pointer type"; - return false; - } - - if (storageClass == SpvStorageClassUniformConstant || - storageClass == SpvStorageClassInput || - storageClass == SpvStorageClassPushConstant) { - DIAG(pointerIndex) << "OpStore Pointer '" - << inst->words[pointerIndex] - << "' storage class is read-only"; - return false; - } - } - - auto objectIndex = 2; - auto object = module_.FindDef(inst->words[objectIndex]); - if (!object || !object->type_id()) { - DIAG(objectIndex) << "OpStore Object '" << inst->words[objectIndex] - << "' is not an object."; - return false; - } - auto objectType = module_.FindDef(object->type_id()); - assert(objectType); - if (SpvOpTypeVoid == objectType->opcode()) { - DIAG(objectIndex) << "OpStore Object '" << inst->words[objectIndex] - << "'s type is void."; - return false; - } - - if (type->id() != objectType->id()) { - if (!module_.options()->relax_struct_store || - type->opcode() != SpvOpTypeStruct || - objectType->opcode() != SpvOpTypeStruct) { - DIAG(pointerIndex) << "OpStore Pointer '" - << inst->words[pointerIndex] - << "'s type does not match Object '" - << object->id() << "'s type."; - return false; - } - - // TODO: Check for layout compatible matricies and arrays as well. - if (!AreLayoutCompatibleStructs(type, objectType)) { - DIAG(pointerIndex) << "OpStore Pointer '" - << inst->words[pointerIndex] - << "'s layout does not match Object '" - << object->id() << "'s layout."; - return false; - } - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto targetIndex = 1; - auto target = module_.FindDef(inst->words[targetIndex]); - if (!target) return false; - auto sourceIndex = 2; - auto source = module_.FindDef(inst->words[sourceIndex]); - if (!source) return false; - auto targetPointerType = module_.FindDef(target->type_id()); - assert(targetPointerType); - auto targetType = module_.FindDef(targetPointerType->words()[3]); - assert(targetType); - auto sourcePointerType = module_.FindDef(source->type_id()); - assert(sourcePointerType); - auto sourceType = module_.FindDef(sourcePointerType->words()[3]); - assert(sourceType); - if (targetType->id() != sourceType->id()) { - DIAG(sourceIndex) << "OpCopyMemory Target '" - << inst->words[sourceIndex] - << "'s type does not match Source '" - << sourceType->id() << "'s type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto targetIndex = 1; - auto target = module_.FindDef(inst->words[targetIndex]); - if (!target) return false; - auto sourceIndex = 2; - auto source = module_.FindDef(inst->words[sourceIndex]); - if (!source) return false; - auto sizeIndex = 3; - auto size = module_.FindDef(inst->words[sizeIndex]); - if (!size) return false; - auto targetPointerType = module_.FindDef(target->type_id()); - if (!targetPointerType || SpvOpTypePointer != targetPointerType->opcode()) { - DIAG(targetIndex) << "OpCopyMemorySized Target '" - << inst->words[targetIndex] << "' is not a pointer."; - return false; - } - auto sourcePointerType = module_.FindDef(source->type_id()); - if (!sourcePointerType || SpvOpTypePointer != sourcePointerType->opcode()) { - DIAG(sourceIndex) << "OpCopyMemorySized Source '" - << inst->words[sourceIndex] << "' is not a pointer."; - return false; - } - switch (size->opcode()) { - // TODO: The following opcode's are assumed to be valid, refer to the - // following bug https://cvs.khronos.org/bugzilla/show_bug.cgi?id=13871 for - // clarification - case SpvOpConstant: - case SpvOpSpecConstant: { - auto sizeType = module_.FindDef(size->type_id()); - assert(sizeType); - if (SpvOpTypeInt != sizeType->opcode()) { - DIAG(sizeIndex) << "OpCopyMemorySized Size '" - << inst->words[sizeIndex] - << "'s type is not an integer type."; - return false; - } - } break; - case SpvOpVariable: { - auto pointerType = module_.FindDef(size->type_id()); - assert(pointerType); - auto sizeType = module_.FindDef(pointerType->type_id()); - if (!sizeType || SpvOpTypeInt != sizeType->opcode()) { - DIAG(sizeIndex) << "OpCopyMemorySized Size '" - << inst->words[sizeIndex] - << "'s variable type is not an integer type."; - return false; - } - } break; - default: - DIAG(sizeIndex) << "OpCopyMemorySized Size '" - << inst->words[sizeIndex] - << "' is not a constant or variable."; - return false; - } - // TODO: Check that consant is a least size 1, see the same bug as above for - // clarification? - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - std::string instr_name = - "Op" + std::string(spvOpcodeString(static_cast(inst->opcode))); - - // The result type must be OpTypePointer. Result Type is at word 1. - auto resultTypeIndex = 1; - auto resultTypeInstr = module_.FindDef(inst->words[resultTypeIndex]); - if (SpvOpTypePointer != resultTypeInstr->opcode()) { - DIAG(resultTypeIndex) << "The Result Type of " << instr_name << " '" - << inst->words[2] - << "' must be OpTypePointer. Found Op" - << spvOpcodeString( - static_cast(resultTypeInstr->opcode())) - << "."; - return false; - } - - // Result type is a pointer. Find out what it's pointing to. - // This will be used to make sure the indexing results in the same type. - // OpTypePointer word 3 is the type being pointed to. - auto resultTypePointedTo = module_.FindDef(resultTypeInstr->word(3)); - - // Base must be a pointer, pointing to the base of a composite object. - auto baseIdIndex = 3; - auto baseInstr = module_.FindDef(inst->words[baseIdIndex]); - auto baseTypeInstr = module_.FindDef(baseInstr->type_id()); - if (!baseTypeInstr || SpvOpTypePointer != baseTypeInstr->opcode()) { - DIAG(baseIdIndex) << "The Base '" << inst->words[baseIdIndex] - << "' in " << instr_name - << " instruction must be a pointer."; - return false; - } - - // The result pointer storage class and base pointer storage class must match. - // Word 2 of OpTypePointer is the Storage Class. - auto resultTypeStorageClass = resultTypeInstr->word(2); - auto baseTypeStorageClass = baseTypeInstr->word(2); - if (resultTypeStorageClass != baseTypeStorageClass) { - DIAG(resultTypeIndex) << "The result pointer storage class and base " - "pointer storage class in " - << instr_name << " do not match."; - return false; - } - - // The type pointed to by OpTypePointer (word 3) must be a composite type. - auto typePointedTo = module_.FindDef(baseTypeInstr->word(3)); - - // Check Universal Limit (SPIR-V Spec. Section 2.17). - // The number of indexes passed to OpAccessChain may not exceed 255 - // The instruction includes 4 words + N words (for N indexes) - const size_t num_indexes = inst->words.size() - 4; - const size_t num_indexes_limit = - module_.options()->universal_limits_.max_access_chain_indexes; - if (num_indexes > num_indexes_limit) { - DIAG(resultTypeIndex) << "The number of indexes in " << instr_name - << " may not exceed " << num_indexes_limit - << ". Found " << num_indexes << " indexes."; - return false; - } - // Indexes walk the type hierarchy to the desired depth, potentially down to - // scalar granularity. The first index in Indexes will select the top-level - // member/element/component/element of the base composite. All composite - // constituents use zero-based numbering, as described by their OpType... - // instruction. The second index will apply similarly to that result, and so - // on. Once any non-composite type is reached, there must be no remaining - // (unused) indexes. - for (size_t i = 4; i < inst->words.size(); ++i) { - const uint32_t cur_word = inst->words[i]; - // Earlier ID checks ensure that cur_word definition exists. - auto cur_word_instr = module_.FindDef(cur_word); - // The index must be a scalar integer type (See OpAccessChain in the Spec.) - auto indexTypeInstr = module_.FindDef(cur_word_instr->type_id()); - if (!indexTypeInstr || SpvOpTypeInt != indexTypeInstr->opcode()) { - DIAG(i) << "Indexes passed to " << instr_name - << " must be of type integer."; - return false; - } - switch (typePointedTo->opcode()) { - case SpvOpTypeMatrix: - case SpvOpTypeVector: - case SpvOpTypeArray: - case SpvOpTypeRuntimeArray: { - // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray, - // word 2 is the Element Type. - typePointedTo = module_.FindDef(typePointedTo->word(2)); - break; - } - case SpvOpTypeStruct: { - // In case of structures, there is an additional constraint on the - // index: the index must be an OpConstant. - if (SpvOpConstant != cur_word_instr->opcode()) { - DIAG(i) << "The passed to " << instr_name - << " to index into a " - "structure must be an OpConstant."; - return false; - } - // Get the index value from the OpConstant (word 3 of OpConstant). - // OpConstant could be a signed integer. But it's okay to treat it as - // unsigned because a negative constant int would never be seen as - // correct as a struct offset, since structs can't have more than 2 - // billion members. - const uint32_t cur_index = cur_word_instr->word(3); - // The index points to the struct member we want, therefore, the index - // should be less than the number of struct members. - const uint32_t num_struct_members = - static_cast(typePointedTo->words().size() - 2); - if (cur_index >= num_struct_members) { - DIAG(i) << "Index is out of bounds: " << instr_name - << " can not find index " << cur_index - << " into the structure '" << typePointedTo->id() - << "'. This structure has " << num_struct_members - << " members. Largest valid index is " - << num_struct_members - 1 << "."; - return false; - } - // Struct members IDs start at word 2 of OpTypeStruct. - auto structMemberId = typePointedTo->word(cur_index + 2); - typePointedTo = module_.FindDef(structMemberId); - break; - } - default: { - // Give an error. reached non-composite type while indexes still remain. - DIAG(i) << instr_name - << " reached non-composite type while indexes " - "still remain to be traversed."; - return false; - } - } - } - // At this point, we have fully walked down from the base using the indeces. - // The type being pointed to should be the same as the result type. - if (typePointedTo->id() != resultTypePointedTo->id()) { - DIAG(resultTypeIndex) - << instr_name << " result type (Op" - << spvOpcodeString(static_cast(resultTypePointedTo->opcode())) - << ") does not match the type that results from indexing into the base " - " (Op" - << spvOpcodeString(static_cast(typePointedTo->opcode())) << ")."; - return false; - } - - return true; -} - -template <> -bool idUsage::isValid( - const spv_instruction_t* inst, const spv_opcode_desc opcodeEntry) { - return isValid(inst, opcodeEntry); -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc opcodeEntry) { - // OpPtrAccessChain's validation rules are similar to OpAccessChain, with one - // difference: word 4 must be id of an integer (Element ). - // The grammar guarantees that there are at least 5 words in the instruction - // (i.e. if there are fewer than 5 words, the SPIR-V code will not compile.) - int elem_index = 4; - // We can remove the Element from the instruction words, and simply call - // the validation code of OpAccessChain. - spv_instruction_t new_inst = *inst; - new_inst.words.erase(new_inst.words.begin() + elem_index); - return isValid(&new_inst, opcodeEntry); -} - -template <> -bool idUsage::isValid( - const spv_instruction_t* inst, const spv_opcode_desc opcodeEntry) { - // Has the same validation rules as OpPtrAccessChain - return isValid(inst, opcodeEntry); -} - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType) return false; - auto functionTypeIndex = 4; - auto functionType = module_.FindDef(inst->words[functionTypeIndex]); - if (!functionType || SpvOpTypeFunction != functionType->opcode()) { - DIAG(functionTypeIndex) - << "OpFunction Function Type '" << inst->words[functionTypeIndex] - << "' is not a function type."; - return false; - } - auto returnType = module_.FindDef(functionType->words()[2]); - assert(returnType); - if (returnType->id() != resultType->id()) { - DIAG(resultTypeIndex) << "OpFunction Result Type '" - << inst->words[resultTypeIndex] - << "' does not match the Function Type '" - << resultType->id() << "'s return type."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType) return false; - // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. - size_t paramIndex = 0; - assert(firstInst < inst && "Invalid instruction pointer"); - while (firstInst != --inst) { - if (SpvOpFunction == inst->opcode) { - break; - } else if (SpvOpFunctionParameter == inst->opcode) { - paramIndex++; - } - } - auto functionType = module_.FindDef(inst->words[4]); - assert(functionType); - if (paramIndex >= functionType->words().size() - 3) { - DIAG(0) << "Too many OpFunctionParameters for " << inst->words[2] - << ": expected " << functionType->words().size() - 3 - << " based on the function's type"; - return false; - } - auto paramType = module_.FindDef(functionType->words()[paramIndex + 3]); - assert(paramType); - if (resultType->id() != paramType->id()) { - DIAG(resultTypeIndex) << "OpFunctionParameter Result Type '" - << inst->words[resultTypeIndex] - << "' does not match the OpTypeFunction parameter " - "type of the same index."; - return false; - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType) return false; - auto functionIndex = 3; - auto function = module_.FindDef(inst->words[functionIndex]); - if (!function || SpvOpFunction != function->opcode()) { - DIAG(functionIndex) << "OpFunctionCall Function '" - << inst->words[functionIndex] << "' is not a function."; - return false; - } - auto returnType = module_.FindDef(function->type_id()); - assert(returnType); - if (returnType->id() != resultType->id()) { - DIAG(resultTypeIndex) << "OpFunctionCall Result Type '" - << inst->words[resultTypeIndex] - << "'s type does not match Function '" - << returnType->id() << "'s return type."; - return false; - } - auto functionType = module_.FindDef(function->words()[4]); - assert(functionType); - auto functionCallArgCount = inst->words.size() - 4; - auto functionParamCount = functionType->words().size() - 3; - if (functionParamCount != functionCallArgCount) { - DIAG(inst->words.size() - 1) - << "OpFunctionCall Function 's parameter count does not match " - "the argument count."; - return false; - } - for (size_t argumentIndex = 4, paramIndex = 3; - argumentIndex < inst->words.size(); argumentIndex++, paramIndex++) { - auto argument = module_.FindDef(inst->words[argumentIndex]); - if (!argument) return false; - auto argumentType = module_.FindDef(argument->type_id()); - assert(argumentType); - auto parameterType = module_.FindDef(functionType->words()[paramIndex]); - assert(parameterType); - if (argumentType->id() != parameterType->id()) { - DIAG(argumentIndex) << "OpFunctionCall Argument '" - << inst->words[argumentIndex] - << "'s type does not match Function '" - << parameterType->id() << "'s parameter type."; - return false; - } - } - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto instr_name = [&inst]() { - std::string name = - "Op" + std::string(spvOpcodeString(static_cast(inst->opcode))); - return name; - }; - - // Result Type must be an OpTypeVector. - auto resultTypeIndex = 1; - auto resultType = module_.FindDef(inst->words[resultTypeIndex]); - if (!resultType || resultType->opcode() != SpvOpTypeVector) { - DIAG(resultTypeIndex) << "The Result Type of " << instr_name() - << " must be OpTypeVector. Found Op" - << spvOpcodeString( - static_cast(resultType->opcode())) - << "."; - return false; - } - - // The number of components in Result Type must be the same as the number of - // Component operands. - auto componentCount = inst->words.size() - 5; - auto vectorComponentCountIndex = 3; - auto resultVectorDimension = resultType->words()[vectorComponentCountIndex]; - if (componentCount != resultVectorDimension) { - DIAG(inst->words.size() - 1) - << instr_name() - << " component literals count does not match " - "Result Type '" - << resultType->id() << "'s vector component count."; - return false; - } - - // Vector 1 and Vector 2 must both have vector types, with the same Component - // Type as Result Type. - auto vector1Index = 3; - auto vector1Object = module_.FindDef(inst->words[vector1Index]); - auto vector1Type = module_.FindDef(vector1Object->type_id()); - auto vector2Index = 4; - auto vector2Object = module_.FindDef(inst->words[vector2Index]); - auto vector2Type = module_.FindDef(vector2Object->type_id()); - if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) { - DIAG(vector1Index) << "The type of Vector 1 must be OpTypeVector."; - return false; - } - if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) { - DIAG(vector2Index) << "The type of Vector 2 must be OpTypeVector."; - return false; - } - auto vectorComponentTypeIndex = 2; - auto resultComponentType = resultType->words()[vectorComponentTypeIndex]; - auto vector1ComponentType = vector1Type->words()[vectorComponentTypeIndex]; - if (vector1ComponentType != resultComponentType) { - DIAG(vector1Index) << "The Component Type of Vector 1 must be the same " - "as ResultType."; - return false; - } - auto vector2ComponentType = vector2Type->words()[vectorComponentTypeIndex]; - if (vector2ComponentType != resultComponentType) { - DIAG(vector2Index) << "The Component Type of Vector 2 must be the same " - "as ResultType."; - return false; - } - - // All Component literals must either be FFFFFFFF or in [0, N - 1]. - auto vector1ComponentCount = vector1Type->words()[vectorComponentCountIndex]; - auto vector2ComponentCount = vector2Type->words()[vectorComponentCountIndex]; - auto N = vector1ComponentCount + vector2ComponentCount; - auto firstLiteralIndex = 5; - for (size_t i = firstLiteralIndex; i < inst->words.size(); ++i) { - auto literal = inst->words[i]; - if (literal != 0xFFFFFFFF && literal >= N) { - DIAG(i) << "Component literal value " << literal << " is greater than " - << N - 1 << "."; - return false; - } - } - - return true; -} - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc /*opcodeEntry*/) { - auto thisInst = module_.FindDef(inst->words[2]); - SpvOp typeOp = module_.GetIdOpcode(thisInst->type_id()); - if (!spvOpcodeGeneratesType(typeOp)) { - DIAG(0) << "OpPhi's type " << module_.getIdName(thisInst->type_id()) - << " is not a type instruction."; - return false; - } - - auto block = thisInst->block(); - size_t numInOps = inst->words.size() - 3; - if (numInOps % 2 != 0) { - DIAG(0) << "OpPhi does not have an equal number of incoming values and " - "basic blocks."; - return false; - } - - // Create a uniqued vector of predecessor ids for comparison against - // incoming values. OpBranchConditional %cond %label %label produces two - // predecessors in the CFG. - std::vector predIds; - std::transform(block->predecessors()->begin(), block->predecessors()->end(), - std::back_inserter(predIds), - [](const libspirv::BasicBlock* b) { return b->id(); }); - std::sort(predIds.begin(), predIds.end()); - predIds.erase(std::unique(predIds.begin(), predIds.end()), predIds.end()); - - size_t numEdges = numInOps / 2; - if (numEdges != predIds.size()) { - DIAG(0) << "OpPhi's number of incoming blocks (" << numEdges - << ") does not match block's predecessor count (" - << block->predecessors()->size() << ")."; - return false; - } - - for (size_t i = 3; i < inst->words.size(); ++i) { - auto incId = inst->words[i]; - if (i % 2 == 1) { - // Incoming value type must match the phi result type. - auto incTypeId = module_.GetTypeId(incId); - if (thisInst->type_id() != incTypeId) { - DIAG(i) << "OpPhi's result type " - << module_.getIdName(thisInst->type_id()) - << " does not match incoming value " - << module_.getIdName(incId) << " type " - << module_.getIdName(incTypeId) << "."; - return false; - } - } else { - if (module_.GetIdOpcode(incId) != SpvOpLabel) { - DIAG(i) << "OpPhi's incoming basic block " - << module_.getIdName(incId) << " is not an OpLabel."; - return false; - } - - // Incoming basic block must be an immediate predecessor of the phi's - // block. - if (!std::binary_search(predIds.begin(), predIds.end(), incId)) { - DIAG(i) << "OpPhi's incoming basic block " - << module_.getIdName(incId) << " is not a predecessor of " - << module_.getIdName(block->id()) << "."; - return false; - } - } - } - - return true; -} - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - const size_t numOperands = inst->words.size() - 1; - const size_t condOperandIndex = 1; - const size_t targetTrueIndex = 2; - const size_t targetFalseIndex = 3; - - // num_operands is either 3 or 5 --- if 5, the last two need to be literal - // integers - if (numOperands != 3 && numOperands != 5) { - DIAG(0) << "OpBranchConditional requires either 3 or 5 parameters"; - return false; - } - - bool ret = true; - - // grab the condition operand and check that it is a bool - const auto condOp = module_.FindDef(inst->words[condOperandIndex]); - if (!condOp || !module_.IsBoolScalarType(condOp->type_id())) { - DIAG(0) - << "Condition operand for OpBranchConditional must be of boolean type"; - ret = false; - } - - // target operands must be OpLabel - // note that we don't need to check that the target labels are in the same - // function, - // PerformCfgChecks already checks for that - const auto targetOpTrue = module_.FindDef(inst->words[targetTrueIndex]); - if (!targetOpTrue || SpvOpLabel != targetOpTrue->opcode()) { - DIAG(0) << "The 'True Label' operand for OpBranchConditional must be the " - "ID of an OpLabel instruction"; - ret = false; - } - - const auto targetOpFalse = module_.FindDef(inst->words[targetFalseIndex]); - if (!targetOpFalse || SpvOpLabel != targetOpFalse->opcode()) { - DIAG(0) << "The 'False Label' operand for OpBranchConditional must be the " - "ID of an OpLabel instruction"; - ret = false; - } - - return ret; -} - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -template <> -bool idUsage::isValid(const spv_instruction_t* inst, - const spv_opcode_desc) { - auto valueIndex = 1; - auto value = module_.FindDef(inst->words[valueIndex]); - if (!value || !value->type_id()) { - DIAG(valueIndex) << "OpReturnValue Value '" << inst->words[valueIndex] - << "' does not represent a value."; - return false; - } - auto valueType = module_.FindDef(value->type_id()); - if (!valueType || SpvOpTypeVoid == valueType->opcode()) { - DIAG(valueIndex) << "OpReturnValue value's type '" << value->type_id() - << "' is missing or void."; - return false; - } - - const bool uses_variable_pointer = - module_.features().variable_pointers || - module_.features().variable_pointers_storage_buffer; - - if (addressingModel == SpvAddressingModelLogical && - SpvOpTypePointer == valueType->opcode() && !uses_variable_pointer && - !module_.options()->relax_logcial_pointer) { - DIAG(valueIndex) - << "OpReturnValue value's type '" << value->type_id() - << "' is a pointer, which is invalid in the Logical addressing model."; - return false; - } - - // NOTE: Find OpFunction - const spv_instruction_t* function = inst - 1; - while (firstInst != function) { - if (SpvOpFunction == function->opcode) break; - function--; - } - if (SpvOpFunction != function->opcode) { - DIAG(valueIndex) << "OpReturnValue is not in a basic block."; - return false; - } - auto returnType = module_.FindDef(function->words[1]); - if (!returnType || returnType->id() != valueType->id()) { - DIAG(valueIndex) << "OpReturnValue Value '" << inst->words[valueIndex] - << "'s type does not match OpFunction's return type."; - return false; - } - return true; -} - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) { -} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) { -} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) { -} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid(const spv_instruction_t *inst, - const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#if 0 -template <> -bool idUsage::isValid( - const spv_instruction_t *inst, const spv_opcode_desc opcodeEntry) {} -#endif - -#undef DIAG - -bool idUsage::isValid(const spv_instruction_t* inst) { - spv_opcode_desc opcodeEntry = nullptr; - if (spvOpcodeTableValueLookup(targetEnv, opcodeTable, inst->opcode, - &opcodeEntry)) - return false; -#define CASE(OpCode) \ - case Spv##OpCode: \ - return isValid(inst, opcodeEntry); -#define TODO(OpCode) \ - case Spv##OpCode: \ - return true; - switch (inst->opcode) { - TODO(OpUndef) - CASE(OpMemberName) - CASE(OpLine) - CASE(OpDecorate) - CASE(OpMemberDecorate) - CASE(OpDecorationGroup) - CASE(OpGroupDecorate) - CASE(OpGroupMemberDecorate) - TODO(OpExtInst) - CASE(OpEntryPoint) - CASE(OpExecutionMode) - CASE(OpTypeVector) - CASE(OpTypeMatrix) - CASE(OpTypeSampler) - CASE(OpTypeArray) - CASE(OpTypeRuntimeArray) - CASE(OpTypeStruct) - CASE(OpTypePointer) - CASE(OpTypeFunction) - CASE(OpTypePipe) - CASE(OpConstantTrue) - CASE(OpConstantFalse) - CASE(OpConstantComposite) - CASE(OpConstantSampler) - CASE(OpConstantNull) - CASE(OpSpecConstantTrue) - CASE(OpSpecConstantFalse) - CASE(OpSpecConstantComposite) - CASE(OpSampledImage) - TODO(OpSpecConstantOp) - CASE(OpVariable) - CASE(OpLoad) - CASE(OpStore) - CASE(OpCopyMemory) - CASE(OpCopyMemorySized) - CASE(OpAccessChain) - CASE(OpInBoundsAccessChain) - CASE(OpPtrAccessChain) - CASE(OpInBoundsPtrAccessChain) - TODO(OpArrayLength) - TODO(OpGenericPtrMemSemantics) - CASE(OpFunction) - CASE(OpFunctionParameter) - CASE(OpFunctionCall) - // Conversion opcodes are validated in validate_conversion.cpp. - CASE(OpVectorShuffle) - // Other composite opcodes are validated in validate_composites.cpp. - // Arithmetic opcodes are validated in validate_arithmetics.cpp. - // Bitwise opcodes are validated in validate_bitwise.cpp. - // Logical opcodes are validated in validate_logicals.cpp. - // Derivative opcodes are validated in validate_derivatives.cpp. - CASE(OpPhi) - TODO(OpLoopMerge) - TODO(OpSelectionMerge) - TODO(OpBranch) - CASE(OpBranchConditional) - TODO(OpSwitch) - CASE(OpReturnValue) - TODO(OpLifetimeStart) - TODO(OpLifetimeStop) - TODO(OpAtomicLoad) - TODO(OpAtomicStore) - TODO(OpAtomicExchange) - TODO(OpAtomicCompareExchange) - TODO(OpAtomicCompareExchangeWeak) - TODO(OpAtomicIIncrement) - TODO(OpAtomicIDecrement) - TODO(OpAtomicIAdd) - TODO(OpAtomicISub) - TODO(OpAtomicUMin) - TODO(OpAtomicUMax) - TODO(OpAtomicAnd) - TODO(OpAtomicOr) - TODO(OpAtomicSMin) - TODO(OpAtomicSMax) - TODO(OpEmitStreamVertex) - TODO(OpEndStreamPrimitive) - TODO(OpGroupAsyncCopy) - TODO(OpGroupWaitEvents) - TODO(OpGroupAll) - TODO(OpGroupAny) - TODO(OpGroupBroadcast) - TODO(OpGroupIAdd) - TODO(OpGroupFAdd) - TODO(OpGroupFMin) - TODO(OpGroupUMin) - TODO(OpGroupSMin) - TODO(OpGroupFMax) - TODO(OpGroupUMax) - TODO(OpGroupSMax) - TODO(OpEnqueueMarker) - TODO(OpEnqueueKernel) - TODO(OpGetKernelNDrangeSubGroupCount) - TODO(OpGetKernelNDrangeMaxSubGroupSize) - TODO(OpGetKernelWorkGroupSize) - TODO(OpGetKernelPreferredWorkGroupSizeMultiple) - TODO(OpRetainEvent) - TODO(OpReleaseEvent) - TODO(OpCreateUserEvent) - TODO(OpIsValidEvent) - TODO(OpSetUserEventStatus) - TODO(OpCaptureEventProfilingInfo) - TODO(OpGetDefaultQueue) - TODO(OpBuildNDRange) - TODO(OpReadPipe) - TODO(OpWritePipe) - TODO(OpReservedReadPipe) - TODO(OpReservedWritePipe) - TODO(OpReserveReadPipePackets) - TODO(OpReserveWritePipePackets) - TODO(OpCommitReadPipe) - TODO(OpCommitWritePipe) - TODO(OpIsValidReserveId) - TODO(OpGetNumPipePackets) - TODO(OpGetMaxPipePackets) - TODO(OpGroupReserveReadPipePackets) - TODO(OpGroupReserveWritePipePackets) - TODO(OpGroupCommitReadPipe) - TODO(OpGroupCommitWritePipe) - default: - return true; - } -#undef TODO -#undef CASE -} - -bool idUsage::AreLayoutCompatibleStructs(const libspirv::Instruction* type1, - const libspirv::Instruction* type2) { - if (type1->opcode() != SpvOpTypeStruct) { - return false; - } - if (type2->opcode() != SpvOpTypeStruct) { - return false; - } - - if (!HaveLayoutCompatibleMembers(type1, type2)) return false; - - return HaveSameLayoutDecorations(type1, type2); -} - -bool idUsage::HaveLayoutCompatibleMembers(const libspirv::Instruction* type1, - const libspirv::Instruction* type2) { - assert(type1->opcode() == SpvOpTypeStruct && - "type1 must be and OpTypeStruct instruction."); - assert(type2->opcode() == SpvOpTypeStruct && - "type2 must be and OpTypeStruct instruction."); - const auto& type1_operands = type1->operands(); - const auto& type2_operands = type2->operands(); - if (type1_operands.size() != type2_operands.size()) { - return false; - } - - for (size_t operand = 2; operand < type1_operands.size(); ++operand) { - if (type1->word(operand) != type2->word(operand)) { - auto def1 = module_.FindDef(type1->word(operand)); - auto def2 = module_.FindDef(type2->word(operand)); - if (!AreLayoutCompatibleStructs(def1, def2)) { - return false; - } - } - } - return true; -} - -bool idUsage::HaveSameLayoutDecorations(const libspirv::Instruction* type1, - const libspirv::Instruction* type2) { - assert(type1->opcode() == SpvOpTypeStruct && - "type1 must be and OpTypeStruct instruction."); - assert(type2->opcode() == SpvOpTypeStruct && - "type2 must be and OpTypeStruct instruction."); - const std::vector& type1_decorations = - module_.id_decorations(type1->id()); - const std::vector& type2_decorations = - module_.id_decorations(type2->id()); - - // TODO: Will have to add other check for arrays an matricies if we want to - // handle them. - if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) { - return false; - } - - return true; -} - -bool idUsage::HasConflictingMemberOffsets( - const vector& type1_decorations, - const vector& type2_decorations) const { - { - // We are interested in conflicting decoration. If a decoration is in one - // list but not the other, then we will assume the code is correct. We are - // looking for things we know to be wrong. - // - // We do not have to traverse type2_decoration because, after traversing - // type1_decorations, anything new will not be found in - // type1_decoration. Therefore, it cannot lead to a conflict. - for (const Decoration& decoration : type1_decorations) { - switch (decoration.dec_type()) { - case SpvDecorationOffset: { - // Since these affect the layout of the struct, they must be present - // in both structs. - auto compare = [&decoration](const Decoration& rhs) { - if (rhs.dec_type() != SpvDecorationOffset) return false; - return decoration.struct_member_index() == - rhs.struct_member_index(); - }; - auto i = find_if(type2_decorations.begin(), type2_decorations.end(), - compare); - if (i != type2_decorations.end() && - decoration.params().front() != i->params().front()) { - return true; - } - } break; - default: - // This decoration does not affect the layout of the structure, so - // just moving on. - break; - } - } - } - return false; -} -} // anonymous namespace - -namespace libspirv { - -spv_result_t UpdateIdUse(ValidationState_t& _) { - for (const auto& inst : _.ordered_instructions()) { - for (auto& operand : inst.operands()) { - const spv_operand_type_t& type = operand.type; - const uint32_t operand_id = inst.word(operand.offset); - if (spvIsIdType(type) && type != SPV_OPERAND_TYPE_RESULT_ID) { - if (auto def = _.FindDef(operand_id)) - def->RegisterUse(&inst, operand.offset); - } - } - } - return SPV_SUCCESS; -} - -/// This function checks all ID definitions dominate their use in the CFG. -/// -/// This function will iterate over all ID definitions that are defined in the -/// functions of a module and make sure that the definitions appear in a -/// block that dominates their use. -/// -/// NOTE: This function does NOT check module scoped functions which are -/// checked during the initial binary parse in the IdPass below -spv_result_t CheckIdDefinitionDominateUse(const ValidationState_t& _) { - unordered_set phi_instructions; - for (const auto& definition : _.all_definitions()) { - // Check only those definitions defined in a function - if (const Function* func = definition.second->function()) { - if (const BasicBlock* block = definition.second->block()) { - if (!block->reachable()) continue; - // If the Id is defined within a block then make sure all references to - // that Id appear in a blocks that are dominated by the defining block - for (auto& use_index_pair : definition.second->uses()) { - const Instruction* use = use_index_pair.first; - if (const BasicBlock* use_block = use->block()) { - if (use_block->reachable() == false) continue; - if (use->opcode() == SpvOpPhi) { - phi_instructions.insert(use); - } else if (!block->dominates(*use->block())) { - return _.diag(SPV_ERROR_INVALID_ID) - << "ID " << _.getIdName(definition.first) - << " defined in block " << _.getIdName(block->id()) - << " does not dominate its use in block " - << _.getIdName(use_block->id()); - } - } - } - } else { - // If the Ids defined within a function but not in a block(i.e. function - // parameters, block ids), then make sure all references to that Id - // appear within the same function - for (auto use : definition.second->uses()) { - const Instruction* inst = use.first; - if (inst->function() && inst->function() != func) { - return _.diag(SPV_ERROR_INVALID_ID) - << "ID " << _.getIdName(definition.first) - << " used in function " - << _.getIdName(inst->function()->id()) - << " is used outside of it's defining function " - << _.getIdName(func->id()); - } - } - } - } - // NOTE: Ids defined outside of functions must appear before they are used - // This check is being performed in the IdPass function - } - - // Check all OpPhi parent blocks are dominated by the variable's defining - // blocks - for (const Instruction* phi : phi_instructions) { - if (phi->block()->reachable() == false) continue; - for (size_t i = 3; i < phi->operands().size(); i += 2) { - const Instruction* variable = _.FindDef(phi->word(i)); - const BasicBlock* parent = - phi->function()->GetBlock(phi->word(i + 1)).first; - if (variable->block() && parent->reachable() && - !variable->block()->dominates(*parent)) { - return _.diag(SPV_ERROR_INVALID_ID) - << "In OpPhi instruction " << _.getIdName(phi->id()) << ", ID " - << _.getIdName(variable->id()) - << " definition does not dominate its parent " - << _.getIdName(parent->id()); - } - } - } - - return SPV_SUCCESS; -} - -// Performs SSA validation on the IDs of an instruction. The -// can_have_forward_declared_ids functor should return true if the -// instruction operand's ID can be forward referenced. -spv_result_t IdPass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - auto can_have_forward_declared_ids = - spvOperandCanBeForwardDeclaredFunction(static_cast(inst->opcode)); - - // Keep track of a result id defined by this instruction. 0 means it - // does not define an id. - uint32_t result_id = 0; - - for (unsigned i = 0; i < inst->num_operands; i++) { - const spv_parsed_operand_t& operand = inst->operands[i]; - const spv_operand_type_t& type = operand.type; - // We only care about Id operands, which are a single word. - const uint32_t operand_word = inst->words[operand.offset]; - - auto ret = SPV_ERROR_INTERNAL; - switch (type) { - case SPV_OPERAND_TYPE_RESULT_ID: - // NOTE: Multiple Id definitions are being checked by the binary parser. - // - // Defer undefined-forward-reference removal until after we've analyzed - // the remaining operands to this instruction. Deferral only matters - // for - // OpPhi since it's the only case where it defines its own forward - // reference. Other instructions that can have forward references - // either don't define a value or the forward reference is to a function - // Id (and hence defined outside of a function body). - result_id = operand_word; - // NOTE: The result Id is added (in RegisterInstruction) *after* all of - // the other Ids have been checked to avoid premature use in the same - // instruction. - ret = SPV_SUCCESS; - break; - case SPV_OPERAND_TYPE_ID: - case SPV_OPERAND_TYPE_TYPE_ID: - case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: - case SPV_OPERAND_TYPE_SCOPE_ID: - if (_.IsDefinedId(operand_word)) { - ret = SPV_SUCCESS; - } else if (can_have_forward_declared_ids(i)) { - ret = _.ForwardDeclareId(operand_word); - } else { - ret = _.diag(SPV_ERROR_INVALID_ID) - << "ID " << _.getIdName(operand_word) - << " has not been defined"; - } - break; - default: - ret = SPV_SUCCESS; - break; - } - if (SPV_SUCCESS != ret) { - return ret; - } - } - if (result_id) { - _.RemoveIfForwardDeclared(result_id); - } - _.RegisterInstruction(*inst); - return SPV_SUCCESS; -} -} // namespace libspirv - -spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, - const uint64_t instCount, - const libspirv::ValidationState_t& state, - spv_position position) { - idUsage idUsage(state.context(), pInsts, instCount, state.memory_model(), - state.addressing_model(), state, state.entry_points(), - position, state.context()->consumer); - for (uint64_t instIndex = 0; instIndex < instCount; ++instIndex) { - if (!idUsage.isValid(&pInsts[instIndex])) return SPV_ERROR_INVALID_ID; - position->index += pInsts[instIndex].words.size(); - } - return SPV_SUCCESS; -} diff --git a/3rdparty/spirv-tools/source/validate_image.cpp b/3rdparty/spirv-tools/source/validate_image.cpp deleted file mode 100644 index 1d671b954..000000000 --- a/3rdparty/spirv-tools/source/validate_image.cpp +++ /dev/null @@ -1,1670 +0,0 @@ -// Copyright (c) 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Validates correctness of image instructions. - -#include "validate.h" - -#include "diagnostic.h" -#include "opcode.h" -#include "spirv_target_env.h" -#include "util/bitutils.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { - -namespace { - -// Performs compile time check that all SpvImageOperandsXXX cases are handled in -// this module. If SpvImageOperandsXXX list changes, this function will fail the -// build. -// For all other purposes this is a dummy function. -bool CheckAllImageOperandsHandled() { - SpvImageOperandsMask enum_val = SpvImageOperandsBiasMask; - - // Some improvised code to prevent the compiler from considering enum_val - // constant and optimizing the switch away. - uint32_t stack_var = 0; - if (reinterpret_cast(&stack_var) % 256) - enum_val = SpvImageOperandsLodMask; - - switch (enum_val) { - // Please update the validation rules in this module if you are changing - // the list of image operands, and add new enum values to this switch. - case SpvImageOperandsMaskNone: - return false; - case SpvImageOperandsBiasMask: - case SpvImageOperandsLodMask: - case SpvImageOperandsGradMask: - case SpvImageOperandsConstOffsetMask: - case SpvImageOperandsOffsetMask: - case SpvImageOperandsConstOffsetsMask: - case SpvImageOperandsSampleMask: - case SpvImageOperandsMinLodMask: - return true; - } - return false; -} - -// Used by GetImageTypeInfo. See OpTypeImage spec for more information. -struct ImageTypeInfo { - uint32_t sampled_type = 0; - SpvDim dim = SpvDimMax; - uint32_t depth = 0; - uint32_t arrayed = 0; - uint32_t multisampled = 0; - uint32_t sampled = 0; - SpvImageFormat format = SpvImageFormatMax; - SpvAccessQualifier access_qualifier = SpvAccessQualifierMax; -}; - -// Provides information on image type. |id| should be object of either -// OpTypeImage or OpTypeSampledImage type. Returns false in case of failure -// (not a valid id, failed to parse the instruction, etc). -bool GetImageTypeInfo(const ValidationState_t& _, uint32_t id, - ImageTypeInfo* info) { - if (!id || !info) return false; - - const Instruction* inst = _.FindDef(id); - assert(inst); - - if (inst->opcode() == SpvOpTypeSampledImage) { - inst = _.FindDef(inst->word(2)); - assert(inst); - } - - if (inst->opcode() != SpvOpTypeImage) return false; - - const size_t num_words = inst->words().size(); - if (num_words != 9 && num_words != 10) return false; - - info->sampled_type = inst->word(2); - info->dim = static_cast(inst->word(3)); - info->depth = inst->word(4); - info->arrayed = inst->word(5); - info->multisampled = inst->word(6); - info->sampled = inst->word(7); - info->format = static_cast(inst->word(8)); - info->access_qualifier = num_words < 10 - ? SpvAccessQualifierMax - : static_cast(inst->word(9)); - return true; -} - -bool IsImplicitLod(SpvOp opcode) { - switch (opcode) { - case SpvOpImageSampleImplicitLod: - case SpvOpImageSampleDrefImplicitLod: - case SpvOpImageSampleProjImplicitLod: - case SpvOpImageSampleProjDrefImplicitLod: - case SpvOpImageSparseSampleImplicitLod: - case SpvOpImageSparseSampleDrefImplicitLod: - case SpvOpImageSparseSampleProjImplicitLod: - case SpvOpImageSparseSampleProjDrefImplicitLod: - return true; - default: - break; - }; - return false; -} - -bool IsExplicitLod(SpvOp opcode) { - switch (opcode) { - case SpvOpImageSampleExplicitLod: - case SpvOpImageSampleDrefExplicitLod: - case SpvOpImageSampleProjExplicitLod: - case SpvOpImageSampleProjDrefExplicitLod: - case SpvOpImageSparseSampleExplicitLod: - case SpvOpImageSparseSampleDrefExplicitLod: - case SpvOpImageSparseSampleProjExplicitLod: - case SpvOpImageSparseSampleProjDrefExplicitLod: - return true; - default: - break; - }; - return false; -} - -// Returns true if the opcode is a Image instruction which applies -// homogenous projection to the coordinates. -bool IsProj(SpvOp opcode) { - switch (opcode) { - case SpvOpImageSampleProjImplicitLod: - case SpvOpImageSampleProjDrefImplicitLod: - case SpvOpImageSparseSampleProjImplicitLod: - case SpvOpImageSparseSampleProjDrefImplicitLod: - case SpvOpImageSampleProjExplicitLod: - case SpvOpImageSampleProjDrefExplicitLod: - case SpvOpImageSparseSampleProjExplicitLod: - case SpvOpImageSparseSampleProjDrefExplicitLod: - return true; - default: - break; - }; - return false; -} - -// Returns the number of components in a coordinate used to access a texel in -// a single plane of an image with the given parameters. -uint32_t GetPlaneCoordSize(const ImageTypeInfo& info) { - uint32_t plane_size = 0; - // If this switch breaks your build, please add new values below. - switch (info.dim) { - case SpvDim1D: - case SpvDimBuffer: - plane_size = 1; - break; - case SpvDim2D: - case SpvDimRect: - case SpvDimSubpassData: - plane_size = 2; - break; - case SpvDim3D: - case SpvDimCube: - // For Cube direction vector is used instead of UV. - plane_size = 3; - break; - case SpvDimMax: - assert(0); - break; - } - - return plane_size; -} - -// Returns minimal number of coordinates based on image dim, arrayed and whether -// the instruction uses projection coordinates. -uint32_t GetMinCoordSize(SpvOp opcode, const ImageTypeInfo& info) { - if (info.dim == SpvDimCube && - (opcode == SpvOpImageRead || opcode == SpvOpImageWrite || - opcode == SpvOpImageSparseRead)) { - // These opcodes use UV for Cube, not direction vector. - return 3; - } - - return GetPlaneCoordSize(info) + info.arrayed + (IsProj(opcode) ? 1 : 0); -} - -// Checks ImageOperand bitfield and respective operands. -spv_result_t ValidateImageOperands(ValidationState_t& _, - const spv_parsed_instruction_t& inst, - const ImageTypeInfo& info, uint32_t mask, - uint32_t word_index) { - static const bool kAllImageOperandsHandled = CheckAllImageOperandsHandled(); - (void)kAllImageOperandsHandled; - - const SpvOp opcode = static_cast(inst.opcode); - const uint32_t num_words = inst.num_words; - - size_t expected_num_image_operand_words = spvutils::CountSetBits(mask); - if (mask & SpvImageOperandsGradMask) { - // Grad uses two words. - ++expected_num_image_operand_words; - } - - if (expected_num_image_operand_words != num_words - word_index) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Number of image operand ids doesn't correspond to the bit mask: " - << spvOpcodeString(opcode); - } - - if (spvutils::CountSetBits(mask & (SpvImageOperandsOffsetMask | - SpvImageOperandsConstOffsetMask | - SpvImageOperandsConstOffsetsMask)) > 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operands Offset, ConstOffset, ConstOffsets cannot be used " - << "together: " << spvOpcodeString(opcode); - }; - - const bool is_implicit_lod = IsImplicitLod(opcode); - const bool is_explicit_lod = IsExplicitLod(opcode); - - // The checks should be done in the order of definition of OperandImage. - - if (mask & SpvImageOperandsBiasMask) { - if (!is_implicit_lod) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Bias can only be used with ImplicitLod opcodes: " - << spvOpcodeString(opcode); - }; - - const uint32_t type_id = _.GetTypeId(inst.words[word_index++]); - if (!_.IsFloatScalarType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Bias to be float scalar: " - << spvOpcodeString(opcode); - } - - if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && - info.dim != SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Bias requires 'Dim' parameter to be 1D, 2D, 3D " - "or " - << "Cube: " << spvOpcodeString(opcode); - } - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Bias requires 'MS' parameter to be 0: " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsLodMask) { - if (!is_explicit_lod && opcode != SpvOpImageFetch && - opcode != SpvOpImageSparseFetch) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Lod can only be used with ExplicitLod opcodes " - << "and OpImageFetch: " << spvOpcodeString(opcode); - }; - - if (mask & SpvImageOperandsGradMask) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand bits Lod and Grad cannot be set at the same " - "time: " - << spvOpcodeString(opcode); - } - - const uint32_t type_id = _.GetTypeId(inst.words[word_index++]); - if (is_explicit_lod) { - if (!_.IsFloatScalarType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Lod to be float scalar when used " - << "with ExplicitLod: " << spvOpcodeString(opcode); - } - } else { - if (!_.IsIntScalarType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Lod to be int scalar when used with " - << "OpImageFetch"; - } - } - - if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && - info.dim != SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Lod requires 'Dim' parameter to be 1D, 2D, 3D " - "or " - << "Cube: " << spvOpcodeString(opcode); - } - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Lod requires 'MS' parameter to be 0: " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsGradMask) { - if (!is_explicit_lod) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Grad can only be used with ExplicitLod opcodes: " - << spvOpcodeString(opcode); - }; - - const uint32_t dx_type_id = _.GetTypeId(inst.words[word_index++]); - const uint32_t dy_type_id = _.GetTypeId(inst.words[word_index++]); - if (!_.IsFloatScalarOrVectorType(dx_type_id) || - !_.IsFloatScalarOrVectorType(dy_type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected both Image Operand Grad ids to be float scalars or " - << "vectors: " << spvOpcodeString(opcode); - } - - const uint32_t plane_size = GetPlaneCoordSize(info); - const uint32_t dx_size = _.GetDimension(dx_type_id); - const uint32_t dy_size = _.GetDimension(dy_type_id); - if (plane_size != dx_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Grad dx to have " << plane_size - << " components, but given " << dx_size << ": " - << spvOpcodeString(opcode); - } - - if (plane_size != dy_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Grad dy to have " << plane_size - << " components, but given " << dy_size << ": " - << spvOpcodeString(opcode); - } - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Grad requires 'MS' parameter to be 0: " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsConstOffsetMask) { - if (info.dim == SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand ConstOffset cannot be used with Cube Image " - "'Dim': " - << spvOpcodeString(opcode); - } - - const uint32_t id = inst.words[word_index++]; - const uint32_t type_id = _.GetTypeId(id); - if (!_.IsIntScalarOrVectorType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffset to be int scalar or " - << "vector: " << spvOpcodeString(opcode); - } - - if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffset to be a const object: " - << spvOpcodeString(opcode); - } - - const uint32_t plane_size = GetPlaneCoordSize(info); - const uint32_t offset_size = _.GetDimension(type_id); - if (plane_size != offset_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffset to have " << plane_size - << " components, but given " << offset_size << ": " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsOffsetMask) { - if (info.dim == SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Offset cannot be used with Cube Image 'Dim': " - << spvOpcodeString(opcode); - } - - const uint32_t id = inst.words[word_index++]; - const uint32_t type_id = _.GetTypeId(id); - if (!_.IsIntScalarOrVectorType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Offset to be int scalar or " - << "vector: " << spvOpcodeString(opcode); - } - - const uint32_t plane_size = GetPlaneCoordSize(info); - const uint32_t offset_size = _.GetDimension(type_id); - if (plane_size != offset_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Offset to have " << plane_size - << " components, but given " << offset_size << ": " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsConstOffsetsMask) { - if (opcode != SpvOpImageGather && opcode != SpvOpImageDrefGather && - opcode != SpvOpImageSparseGather && - opcode != SpvOpImageSparseDrefGather) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand ConstOffsets can only be used with " - "OpImageGather " - << "and OpImageDrefGather: " << spvOpcodeString(opcode); - } - - if (info.dim == SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand ConstOffsets cannot be used with Cube Image " - "'Dim': " - << spvOpcodeString(opcode); - } - - const uint32_t id = inst.words[word_index++]; - const uint32_t type_id = _.GetTypeId(id); - const Instruction* type_inst = _.FindDef(type_id); - assert(type_inst); - - if (type_inst->opcode() != SpvOpTypeArray) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffsets to be an array of size 4: " - << spvOpcodeString(opcode); - } - - uint64_t array_size = 0; - if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { - assert(0 && "Array type definition is corrupt"); - } - - if (array_size != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffsets to be an array of size 4: " - << spvOpcodeString(opcode); - } - - const uint32_t component_type = type_inst->word(2); - if (!_.IsIntVectorType(component_type) || - _.GetDimension(component_type) != 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffsets array componenets to be " - "int " - << "vectors of size 2: " << spvOpcodeString(opcode); - } - - if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand ConstOffsets to be a const object: " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsSampleMask) { - if (opcode != SpvOpImageFetch && opcode != SpvOpImageRead && - opcode != SpvOpImageWrite && opcode != SpvOpImageSparseFetch && - opcode != SpvOpImageSparseRead) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Sample can only be used with OpImageFetch, " - << "OpImageRead, OpImageWrite, OpImageSparseFetch and " - << "OpImageSparseRead: " << spvOpcodeString(opcode); - } - - if (info.multisampled == 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand Sample requires non-zero 'MS' parameter: " - << spvOpcodeString(opcode); - } - - const uint32_t type_id = _.GetTypeId(inst.words[word_index++]); - if (!_.IsIntScalarType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand Sample to be int scalar: " - << spvOpcodeString(opcode); - } - } - - if (mask & SpvImageOperandsMinLodMask) { - if (!is_implicit_lod && !(mask & SpvImageOperandsGradMask)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand MinLod can only be used with ImplicitLod " - << "opcodes or together with Image Operand Grad: " - << spvOpcodeString(opcode); - }; - - const uint32_t type_id = _.GetTypeId(inst.words[word_index++]); - if (!_.IsFloatScalarType(type_id)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image Operand MinLod to be float scalar: " - << spvOpcodeString(opcode); - } - - if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && - info.dim != SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand MinLod requires 'Dim' parameter to be 1D, 2D, " - "3D " - << "or Cube: " << spvOpcodeString(opcode); - } - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Operand MinLod requires 'MS' parameter to be 0: " - << spvOpcodeString(opcode); - } - } - - return SPV_SUCCESS; -} - -// Checks some of the validation rules which are common to multiple opcodes. -spv_result_t ValidateImageCommon(ValidationState_t& _, - const spv_parsed_instruction_t& inst, - const ImageTypeInfo& info) { - const SpvOp opcode = static_cast(inst.opcode); - if (IsProj(opcode)) { - if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && - info.dim != SpvDimRect) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Dim' parameter to be 1D, 2D, 3D or Rect: " - << spvOpcodeString(opcode); - } - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Image 'MS' parameter to be 0: " - << spvOpcodeString(opcode); - } - - if (info.arrayed != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Image 'arrayed' parameter to be 0: " - << spvOpcodeString(opcode); - } - } - - if (opcode == SpvOpImageRead || opcode == SpvOpImageSparseRead || - opcode == SpvOpImageWrite) { - if (info.sampled == 0) { - } else if (info.sampled == 2) { - if (info.dim == SpvDim1D && !_.HasCapability(SpvCapabilityImage1D)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability Image1D is required to access storage image: " - << spvOpcodeString(opcode); - } else if (info.dim == SpvDimRect && - !_.HasCapability(SpvCapabilityImageRect)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability ImageRect is required to access storage image: " - << spvOpcodeString(opcode); - } else if (info.dim == SpvDimBuffer && - !_.HasCapability(SpvCapabilityImageBuffer)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability ImageBuffer is required to access storage image: " - << spvOpcodeString(opcode); - } else if (info.dim == SpvDimCube && info.arrayed == 1 && - !_.HasCapability(SpvCapabilityImageCubeArray)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability ImageCubeArray is required to access storage " - << "image: " << spvOpcodeString(opcode); - } - - if (info.multisampled == 1 && - !_.HasCapability(SpvCapabilityImageMSArray)) { -#if 0 - // TODO(atgoo@github.com) The description of this rule in the spec - // is unclear and Glslang doesn't declare ImageMSArray. Need to clarify - // and reenable. - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability ImageMSArray is required to access storage " - << "image: " << spvOpcodeString(opcode); -#endif - } - } else { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled' parameter to be 0 or 2: " - << spvOpcodeString(opcode); - } - } - - return SPV_SUCCESS; -} - -// Returns true if opcode is *ImageSparse*, false otherwise. -bool IsSparse(SpvOp opcode) { - switch (opcode) { - case SpvOpImageSparseSampleImplicitLod: - case SpvOpImageSparseSampleExplicitLod: - case SpvOpImageSparseSampleDrefImplicitLod: - case SpvOpImageSparseSampleDrefExplicitLod: - case SpvOpImageSparseSampleProjImplicitLod: - case SpvOpImageSparseSampleProjExplicitLod: - case SpvOpImageSparseSampleProjDrefImplicitLod: - case SpvOpImageSparseSampleProjDrefExplicitLod: - case SpvOpImageSparseFetch: - case SpvOpImageSparseGather: - case SpvOpImageSparseDrefGather: - case SpvOpImageSparseTexelsResident: - case SpvOpImageSparseRead: { - return true; - } - - default: { return false; } - } - - return false; -} - -// Checks sparse image opcode result type and returns the second struct member. -// Returns inst.type_id for non-sparse image opcodes. -// Not valid for sparse image opcodes which do not return a struct. -spv_result_t GetActualResultType(ValidationState_t& _, - const spv_parsed_instruction_t& inst, - uint32_t* actual_result_type) { - const SpvOp opcode = static_cast(inst.opcode); - - if (IsSparse(opcode)) { - const Instruction* const type_inst = _.FindDef(inst.type_id); - assert(type_inst); - - if (!type_inst || type_inst->opcode() != SpvOpTypeStruct) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be OpTypeStruct"; - } - - if (type_inst->words().size() != 4 || - !_.IsIntScalarType(type_inst->word(2))) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be a struct containing an int " - "scalar " - << "and a texel"; - } - - *actual_result_type = type_inst->word(3); - } else { - *actual_result_type = inst.type_id; - } - - return SPV_SUCCESS; -} - -// Returns a string describing actual result type of an opcode. -// Not valid for sparse image opcodes which do not return a struct. -const char* GetActualResultTypeStr(SpvOp opcode) { - if (IsSparse(opcode)) return "Result Type's second member"; - return "Result Type"; -} - -} // namespace - -// Validates correctness of image instructions. -spv_result_t ImagePass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - const SpvOp opcode = static_cast(inst->opcode); - const uint32_t result_type = inst->type_id; - - if (IsImplicitLod(opcode)) { - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelFragment, - "ImplicitLod instructions require Fragment execution model"); - } - - switch (opcode) { - case SpvOpTypeImage: { - assert(result_type == 0); - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, inst->words[1], &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "OpTypeImage: corrupt definition"; - } - - if (spvIsVulkanEnv(_.context()->target_env)) { - if ((!_.IsFloatScalarType(info.sampled_type) && - !_.IsIntScalarType(info.sampled_type)) || - 32 != _.GetBitWidth(info.sampled_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Sampled Type to be a 32-bit int or float " - "scalar type for Vulkan environment"; - } - } else { - const SpvOp sampled_type_opcode = _.GetIdOpcode(info.sampled_type); - if (sampled_type_opcode != SpvOpTypeVoid && - sampled_type_opcode != SpvOpTypeInt && - sampled_type_opcode != SpvOpTypeFloat) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Sampled Type to be either void or numerical " - << "scalar type"; - } - } - - // Dim is checked elsewhere. - - if (info.depth > 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) << ": invalid Depth " << info.depth - << " (must be 0, 1 or 2)"; - } - - if (info.arrayed > 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) << ": invalid Arrayed " - << info.arrayed << " (must be 0 or 1)"; - } - - if (info.multisampled > 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) << ": invalid MS " - << info.multisampled << " (must be 0 or 1)"; - } - - if (info.sampled > 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) << ": invalid Sampled " - << info.sampled << " (must be 0, 1 or 2)"; - } - - if (info.dim == SpvDimSubpassData) { - if (info.sampled != 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": Dim SubpassData requires Sampled to be 2"; - } - - if (info.format != SpvImageFormatUnknown) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": Dim SubpassData requires format Unknown"; - } - } - - // Format and Access Qualifier are checked elsewhere. - - break; - } - - case SpvOpTypeSampledImage: { - const uint32_t image_type = inst->words[2]; - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Image to be of type OpTypeImage"; - } - - break; - } - - case SpvOpSampledImage: { - if (_.GetIdOpcode(result_type) != SpvOpTypeSampledImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be OpTypeSampledImage: " - << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - // TODO(atgoo@github.com) Check compatibility of result type and received - // image. - - if (spvIsVulkanEnv(_.context()->target_env)) { - if (info.sampled != 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled' parameter to be 1 for Vulkan " - "environment: " - << spvOpcodeString(opcode); - } - } else { - if (info.sampled != 0 && info.sampled != 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled' parameter to be 0 or 1: " - << spvOpcodeString(opcode); - } - } - - if (info.dim == SpvDimSubpassData) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Dim' parameter to be not SubpassData: " - << spvOpcodeString(opcode); - } - - if (_.GetIdOpcode(_.GetOperandTypeId(inst, 3)) != SpvOpTypeSampler) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Sampler to be of type OpTypeSampler: " - << spvOpcodeString(opcode); - } - - break; - } - - case SpvOpImageSampleImplicitLod: - case SpvOpImageSampleExplicitLod: - case SpvOpImageSampleProjImplicitLod: - case SpvOpImageSampleProjExplicitLod: - case SpvOpImageSparseSampleImplicitLod: - case SpvOpImageSparseSampleExplicitLod: { - uint32_t actual_result_type = 0; - if (spv_result_t error = - GetActualResultType(_, *inst, &actual_result_type)) { - return error; - } - - if (!_.IsIntVectorType(actual_result_type) && - !_.IsFloatVectorType(actual_result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to be int or float vector type: " - << spvOpcodeString(opcode); - } - - if (_.GetDimension(actual_result_type) != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to have 4 components: " << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Sampled Image to be of type OpTypeSampledImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (spv_result_t result = ValidateImageCommon(_, *inst, info)) - return result; - - if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { - const uint32_t texel_component_type = - _.GetComponentType(actual_result_type); - if (texel_component_type != info.sampled_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled Type' to be the same as " - << GetActualResultTypeStr(opcode) - << " components: " << spvOpcodeString(opcode); - } - } - - const uint32_t coord_type = _.GetOperandTypeId(inst, 3); - if ((opcode == SpvOpImageSampleExplicitLod || - opcode == SpvOpImageSparseSampleExplicitLod) && - _.HasCapability(SpvCapabilityKernel)) { - if (!_.IsFloatScalarOrVectorType(coord_type) && - !_.IsIntScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be int or float scalar or vector: " - << spvOpcodeString(opcode); - } - } else { - if (!_.IsFloatScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be float scalar or vector: " - << spvOpcodeString(opcode); - } - } - - const uint32_t min_coord_size = GetMinCoordSize(opcode, info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - - if (inst->num_words <= 5) { - assert(IsImplicitLod(opcode)); - break; - } - - const uint32_t mask = inst->words[5]; - if (spv_result_t result = - ValidateImageOperands(_, *inst, info, mask, /* word_index = */ 6)) - return result; - - break; - } - - case SpvOpImageSampleDrefImplicitLod: - case SpvOpImageSampleDrefExplicitLod: - case SpvOpImageSampleProjDrefImplicitLod: - case SpvOpImageSampleProjDrefExplicitLod: - case SpvOpImageSparseSampleDrefImplicitLod: - case SpvOpImageSparseSampleDrefExplicitLod: { - uint32_t actual_result_type = 0; - if (spv_result_t error = - GetActualResultType(_, *inst, &actual_result_type)) { - return error; - } - - if (!_.IsIntScalarType(actual_result_type) && - !_.IsFloatScalarType(actual_result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to be int or float scalar type: " - << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Sampled Image to be of type OpTypeSampledImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (spv_result_t result = ValidateImageCommon(_, *inst, info)) - return result; - - if (actual_result_type != info.sampled_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled Type' to be the same as " - << GetActualResultTypeStr(opcode) << ": " - << spvOpcodeString(opcode); - } - - const uint32_t coord_type = _.GetOperandTypeId(inst, 3); - if (!_.IsFloatScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be float scalar or vector: " - << spvOpcodeString(opcode); - } - - const uint32_t min_coord_size = GetMinCoordSize(opcode, info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - - const uint32_t dref_type = _.GetOperandTypeId(inst, 4); - if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": Expected Dref to be of 32-bit float type"; - } - - if (inst->num_words <= 6) { - assert(IsImplicitLod(opcode)); - break; - } - - const uint32_t mask = inst->words[6]; - if (spv_result_t result = - ValidateImageOperands(_, *inst, info, mask, /* word_index = */ 7)) - return result; - - break; - } - - case SpvOpImageFetch: - case SpvOpImageSparseFetch: { - uint32_t actual_result_type = 0; - if (spv_result_t error = - GetActualResultType(_, *inst, &actual_result_type)) { - return error; - } - - if (!_.IsIntVectorType(actual_result_type) && - !_.IsFloatVectorType(actual_result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to be int or float vector type: " - << spvOpcodeString(opcode); - } - - if (_.GetDimension(actual_result_type) != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to have 4 components: " << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { - const uint32_t result_component_type = - _.GetComponentType(actual_result_type); - if (result_component_type != info.sampled_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled Type' to be the same as " - << GetActualResultTypeStr(opcode) - << " components: " << spvOpcodeString(opcode); - } - } - - if (info.dim == SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' cannot be Cube: " << spvOpcodeString(opcode); - } - - if (info.sampled != 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled' parameter to be 1: " - << spvOpcodeString(opcode); - } - - const uint32_t coord_type = _.GetOperandTypeId(inst, 3); - if (!_.IsIntScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be int scalar or vector: " - << spvOpcodeString(opcode); - } - - const uint32_t min_coord_size = GetMinCoordSize(opcode, info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - - if (inst->num_words <= 5) break; - - const uint32_t mask = inst->words[5]; - if (spv_result_t result = - ValidateImageOperands(_, *inst, info, mask, /* word_index = */ 6)) - return result; - - break; - } - - case SpvOpImageGather: - case SpvOpImageDrefGather: - case SpvOpImageSparseGather: - case SpvOpImageSparseDrefGather: { - uint32_t actual_result_type = 0; - if (spv_result_t error = - GetActualResultType(_, *inst, &actual_result_type)) { - return error; - } - - if (!_.IsIntVectorType(actual_result_type) && - !_.IsFloatVectorType(actual_result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to be int or float vector type: " - << spvOpcodeString(opcode); - } - - if (_.GetDimension(actual_result_type) != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to have 4 components: " << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Sampled Image to be of type OpTypeSampledImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (opcode == SpvOpImageDrefGather || - opcode == SpvOpImageSparseDrefGather || - _.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { - const uint32_t result_component_type = - _.GetComponentType(actual_result_type); - if (result_component_type != info.sampled_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled Type' to be the same as " - << GetActualResultTypeStr(opcode) - << " components: " << spvOpcodeString(opcode); - } - } - - if (info.dim != SpvDim2D && info.dim != SpvDimCube && - info.dim != SpvDimRect) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Dim' cannot be Cube: " - << spvOpcodeString(opcode); - } - - const uint32_t coord_type = _.GetOperandTypeId(inst, 3); - if (!_.IsFloatScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be float scalar or vector: " - << spvOpcodeString(opcode); - } - - const uint32_t min_coord_size = GetMinCoordSize(opcode, info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - - if (opcode == SpvOpImageGather || opcode == SpvOpImageSparseGather) { - const uint32_t component_index_type = _.GetOperandTypeId(inst, 4); - if (!_.IsIntScalarType(component_index_type) || - _.GetBitWidth(component_index_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Component to be 32-bit int scalar: " - << spvOpcodeString(opcode); - } - } else { - assert(opcode == SpvOpImageDrefGather || - opcode == SpvOpImageSparseDrefGather); - const uint32_t dref_type = _.GetOperandTypeId(inst, 4); - if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": Expected Dref to be of 32-bit float type"; - } - } - - if (inst->num_words <= 6) break; - - const uint32_t mask = inst->words[6]; - if (spv_result_t result = - ValidateImageOperands(_, *inst, info, mask, /* word_index = */ 7)) - return result; - - break; - } - - case SpvOpImageRead: - case SpvOpImageSparseRead: { - uint32_t actual_result_type = 0; - if (spv_result_t error = - GetActualResultType(_, *inst, &actual_result_type)) { - return error; - } - - if (!_.IsIntScalarOrVectorType(actual_result_type) && - !_.IsFloatScalarOrVectorType(actual_result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to be int or float scalar or vector type: " - << spvOpcodeString(opcode); - } - -#if 0 - // TODO(atgoo@github.com) Disabled until the spec is clarified. - if (_.GetDimension(actual_result_type) != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected " << GetActualResultTypeStr(opcode) - << " to have 4 components: " << spvOpcodeString(opcode); - } -#endif - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (info.dim == SpvDimSubpassData) { - if (opcode == SpvOpImageSparseRead) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image Dim SubpassData cannot be used with " - << spvOpcodeString(opcode); - } - - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelFragment, - std::string("Dim SubpassData requires Fragment execution model: ") + - spvOpcodeString(opcode)); - } - - if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { - const uint32_t result_component_type = - _.GetComponentType(actual_result_type); - if (result_component_type != info.sampled_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled Type' to be the same as " - << GetActualResultTypeStr(opcode) - << " components: " << spvOpcodeString(opcode); - } - } - - if (spv_result_t result = ValidateImageCommon(_, *inst, info)) - return result; - - const uint32_t coord_type = _.GetOperandTypeId(inst, 3); - if (!_.IsIntScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be int scalar or vector: " - << spvOpcodeString(opcode); - } - - const uint32_t min_coord_size = GetMinCoordSize(opcode, info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - - if (info.format == SpvImageFormatUnknown && - info.dim != SpvDimSubpassData && - !_.HasCapability(SpvCapabilityStorageImageReadWithoutFormat)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability StorageImageReadWithoutFormat is required to " - << "read storage image: " << spvOpcodeString(opcode); - } - - if (inst->num_words <= 5) break; - - const uint32_t mask = inst->words[5]; - if (spv_result_t result = - ValidateImageOperands(_, *inst, info, mask, /* word_index = */ 6)) - return result; - - break; - } - - case SpvOpImageWrite: { - const uint32_t image_type = _.GetOperandTypeId(inst, 0); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (info.dim == SpvDimSubpassData) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' cannot be SubpassData: " - << spvOpcodeString(opcode); - } - - if (spv_result_t result = ValidateImageCommon(_, *inst, info)) - return result; - - const uint32_t coord_type = _.GetOperandTypeId(inst, 1); - if (!_.IsIntScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be int scalar or vector: " - << spvOpcodeString(opcode); - } - - const uint32_t min_coord_size = GetMinCoordSize(opcode, info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - - // TODO(atgoo@github.com) The spec doesn't explicitely say what the type - // of texel should be. - const uint32_t texel_type = _.GetOperandTypeId(inst, 2); - if (!_.IsIntScalarOrVectorType(texel_type) && - !_.IsFloatScalarOrVectorType(texel_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Texel to be int or float vector or scalar: " - << spvOpcodeString(opcode); - } - -#if 0 - // TODO: See above. - if (_.GetDimension(texel_type) != 4) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Texel to have 4 components: " - << spvOpcodeString(opcode); - } -#endif - - if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { - const uint32_t texel_component_type = _.GetComponentType(texel_type); - if (texel_component_type != info.sampled_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image 'Sampled Type' to be the same as Texel " - << "components: " << spvOpcodeString(opcode); - } - } - - if (info.format == SpvImageFormatUnknown && - info.dim != SpvDimSubpassData && - !_.HasCapability(SpvCapabilityStorageImageWriteWithoutFormat)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Capability StorageImageWriteWithoutFormat is required to " - "write " - << "to storage image: " << spvOpcodeString(opcode); - } - - if (inst->num_words <= 4) break; - - const uint32_t mask = inst->words[4]; - if (spv_result_t result = - ValidateImageOperands(_, *inst, info, mask, /* word_index = */ 5)) - return result; - - break; - } - - case SpvOpImage: { - if (_.GetIdOpcode(result_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be OpTypeImage: " - << spvOpcodeString(opcode); - } - - const uint32_t sampled_image_type = _.GetOperandTypeId(inst, 2); - const Instruction* sampled_image_type_inst = - _.FindDef(sampled_image_type); - assert(sampled_image_type_inst); - - if (sampled_image_type_inst->opcode() != SpvOpTypeSampledImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Sample Image to be of type OpTypeSampleImage: " - << spvOpcodeString(opcode); - } - - if (sampled_image_type_inst->word(2) != result_type) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Sample Image image type to be equal to Result " - "Type: " - << spvOpcodeString(opcode); - } - - break; - } - - case SpvOpImageQueryFormat: - case SpvOpImageQueryOrder: { - if (!_.IsIntScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be int scalar type: " - << spvOpcodeString(opcode); - } - - if (_.GetIdOpcode(_.GetOperandTypeId(inst, 2)) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected operand to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - break; - } - - case SpvOpImageQuerySizeLod: { - if (!_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be int scalar or vector type: " - << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - uint32_t expected_num_components = info.arrayed; - switch (info.dim) { - case SpvDim1D: - expected_num_components += 1; - break; - case SpvDim2D: - case SpvDimCube: - expected_num_components += 2; - break; - case SpvDim3D: - expected_num_components += 3; - break; - default: - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' must be 1D, 2D, 3D or Cube: " - << spvOpcodeString(opcode); - }; - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'MS' must be 0: " << spvOpcodeString(opcode); - } - - uint32_t result_num_components = _.GetDimension(result_type); - if (result_num_components != expected_num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Result Type has " << result_num_components << " components, " - << "but " << expected_num_components - << " expected: " << spvOpcodeString(opcode); - } - - const uint32_t lod_type = _.GetOperandTypeId(inst, 3); - if (!_.IsIntScalarType(lod_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Level of Detail to be int scalar: " - << spvOpcodeString(opcode); - } - - break; - } - - case SpvOpImageQuerySize: { - if (!_.IsIntScalarOrVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be int scalar or vector type: " - << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - -#if 0 - // TODO(atgoo@github.com) The spec doesn't whitelist all Dims supported by - // GLSL. Need to verify if there is an error and reenable. - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - uint32_t expected_num_components = info.arrayed; - switch (info.dim) { - case SpvDimBuffer: - expected_num_components += 1; - break; - case SpvDim2D: - if (info.multisampled != 1 && info.sampled != 0 && - info.sampled != 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected either 'MS'=1 or 'Sampled'=0 or 'Sampled'=2 " - << "for 2D dim: " << spvOpcodeString(opcode); - } - case SpvDimRect: - expected_num_components += 2; - break; - case SpvDim3D: - expected_num_components += 3; - if (info.sampled != 0 && - info.sampled != 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected either 'Sampled'=0 or 'Sampled'=2 " - << "for 3D dim: " << spvOpcodeString(opcode); - } - break; - default: - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' must be Buffer, 2D, 3D or Rect: " - << spvOpcodeString(opcode); - }; - - - if (info.multisampled != 0) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'MS' must be 0: " << spvOpcodeString(opcode); - } - - uint32_t result_num_components = _.GetDimension(result_type); - if (result_num_components != expected_num_components) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Result Type has " << result_num_components << " components, " - << "but " << expected_num_components << " expected: " - << spvOpcodeString(opcode); - } -#endif - break; - } - - case SpvOpImageQueryLod: { - _.current_function().RegisterExecutionModelLimitation( - SpvExecutionModelFragment, - "OpImageQueryLod requires Fragment execution model"); - - if (!_.IsFloatVectorType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be float vector type: " - << spvOpcodeString(opcode); - } - - if (_.GetDimension(result_type) != 2) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to have 2 components: " - << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image operand to be of type OpTypeSampledImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (info.dim != SpvDim1D && info.dim != SpvDim2D && - info.dim != SpvDim3D && info.dim != SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' must be 1D, 2D, 3D or Cube: " - << spvOpcodeString(opcode); - } - - const uint32_t coord_type = _.GetOperandTypeId(inst, 3); - if (_.HasCapability(SpvCapabilityKernel)) { - if (!_.IsFloatScalarOrVectorType(coord_type) && - !_.IsIntScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be int or float scalar or vector: " - << spvOpcodeString(opcode); - } - } else { - if (!_.IsFloatScalarOrVectorType(coord_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to be float scalar or vector: " - << spvOpcodeString(opcode); - } - } - - const uint32_t min_coord_size = GetPlaneCoordSize(info); - const uint32_t actual_coord_size = _.GetDimension(coord_type); - if (min_coord_size > actual_coord_size) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Coordinate to have at least " << min_coord_size - << " components, but given only " << actual_coord_size << ": " - << spvOpcodeString(opcode); - } - break; - } - - case SpvOpImageQueryLevels: - case SpvOpImageQuerySamples: { - if (!_.IsIntScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Result Type to be int scalar type: " - << spvOpcodeString(opcode); - } - - const uint32_t image_type = _.GetOperandTypeId(inst, 2); - if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Expected Image to be of type OpTypeImage: " - << spvOpcodeString(opcode); - } - - ImageTypeInfo info; - if (!GetImageTypeInfo(_, image_type, &info)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Corrupt image type definition"; - } - - if (opcode == SpvOpImageQueryLevels) { - if (info.dim != SpvDim1D && info.dim != SpvDim2D && - info.dim != SpvDim3D && info.dim != SpvDimCube) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' must be 1D, 2D, 3D or Cube: " - << spvOpcodeString(opcode); - } - } else { - assert(opcode == SpvOpImageQuerySamples); - if (info.dim != SpvDim2D) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'Dim' must be 2D: " << spvOpcodeString(opcode); - } - - if (info.multisampled != 1) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Image 'MS' must be 1: " << spvOpcodeString(opcode); - } - } - - break; - } - - case SpvOpImageSparseSampleProjImplicitLod: - case SpvOpImageSparseSampleProjExplicitLod: - case SpvOpImageSparseSampleProjDrefImplicitLod: - case SpvOpImageSparseSampleProjDrefExplicitLod: { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": instruction reserved for future use, " - << "use of this instruction is invalid"; - } - - case SpvOpImageSparseTexelsResident: { - if (!_.IsBoolScalarType(result_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Result Type to be bool scalar type"; - } - - const uint32_t resident_code_type = _.GetOperandTypeId(inst, 2); - if (!_.IsIntScalarType(resident_code_type)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << spvOpcodeString(opcode) - << ": expected Resident Code to be int scalar"; - } - break; - } - - default: - break; - } - - return SPV_SUCCESS; -} - -} // namespace libspirv diff --git a/3rdparty/spirv-tools/source/validate_type_unique.cpp b/3rdparty/spirv-tools/source/validate_type_unique.cpp deleted file mode 100644 index b7f77ed6c..000000000 --- a/3rdparty/spirv-tools/source/validate_type_unique.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Ensures type declarations are unique unless allowed by the specification. - -#include "validate.h" - -#include "diagnostic.h" -#include "opcode.h" -#include "val/instruction.h" -#include "val/validation_state.h" - -namespace libspirv { - -// Validates that type declarations are unique, unless multiple declarations -// of the same data type are allowed by the specification. -// (see section 2.8 Types and Variables) -// Doesn't do anything if SPV_VAL_ignore_type_decl_unique was declared in the -// module. -spv_result_t TypeUniquePass(ValidationState_t& _, - const spv_parsed_instruction_t* inst) { - if (_.HasExtension(Extension::kSPV_VALIDATOR_ignore_type_decl_unique)) - return SPV_SUCCESS; - - const SpvOp opcode = static_cast(inst->opcode); - - if (spvOpcodeGeneratesType(opcode)) { - if (opcode == SpvOpTypeArray || opcode == SpvOpTypeRuntimeArray || - opcode == SpvOpTypeStruct) { - // Duplicate declarations of aggregates are allowed. - return SPV_SUCCESS; - } - - if (inst->opcode == SpvOpTypePointer && - _.HasExtension(Extension::kSPV_KHR_variable_pointers)) { - // Duplicate pointer types are allowed with this extension. - return SPV_SUCCESS; - } - - if (!_.RegisterUniqueTypeDeclaration(*inst)) { - return _.diag(SPV_ERROR_INVALID_DATA) - << "Duplicate non-aggregate type declarations are not allowed." - << " Opcode: " << spvOpcodeString(SpvOp(inst->opcode)) - << " id: " << inst->result_id; - } - } - - return SPV_SUCCESS; -} - -} // namespace libspirv diff --git a/3rdparty/spirv-tools/test/CMakeLists.txt b/3rdparty/spirv-tools/test/CMakeLists.txt index 5317e1f1f..1fdf5a212 100644 --- a/3rdparty/spirv-tools/test/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/CMakeLists.txt @@ -198,21 +198,29 @@ add_spvtools_unittest( add_spvtools_unittest( TARGET bit_stream SRCS bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h LIBS ${SPIRV_TOOLS}) add_spvtools_unittest( TARGET huffman_codec SRCS huffman_codec.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/huffman_codec.h LIBS ${SPIRV_TOOLS}) add_spvtools_unittest( TARGET move_to_front SRCS move_to_front_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.h + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.cpp LIBS ${SPIRV_TOOLS}) add_subdirectory(comp) add_subdirectory(link) add_subdirectory(opt) add_subdirectory(stats) +add_subdirectory(tools) add_subdirectory(util) add_subdirectory(val) diff --git a/3rdparty/spirv-tools/test/assembly_context_test.cpp b/3rdparty/spirv-tools/test/assembly_context_test.cpp index 65a40aad2..b6d60b95d 100644 --- a/3rdparty/spirv-tools/test/assembly_context_test.cpp +++ b/3rdparty/spirv-tools/test/assembly_context_test.cpp @@ -12,20 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" - -#include +#include #include +#include "gmock/gmock.h" #include "source/instruction.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { -using libspirv::AssemblyContext; using spvtest::AutoText; using spvtest::Concatenate; using ::testing::Eq; -namespace { - struct EncodeStringCase { std::string str; std::vector initial_contents; @@ -73,4 +73,5 @@ INSTANTIATE_TEST_CASE_P( }),); // clang-format on -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/assembly_format_test.cpp b/3rdparty/spirv-tools/test/assembly_format_test.cpp index 953fb8a33..59e500b81 100644 --- a/3rdparty/spirv-tools/test/assembly_format_test.cpp +++ b/3rdparty/spirv-tools/test/assembly_format_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace svptools { namespace { using spvtest::ScopedContext; @@ -32,4 +33,5 @@ TEST_F(TextToBinaryTest, NotPlacingResultIDAtTheBeginning) { EXPECT_EQ(0u, diagnostic->position.line); } -} // anonymous namespace +} // namespace +} // namespace svptools diff --git a/3rdparty/spirv-tools/test/binary_destroy_test.cpp b/3rdparty/spirv-tools/test/binary_destroy_test.cpp index 2df8379c8..e3870c9f0 100644 --- a/3rdparty/spirv-tools/test/binary_destroy_test.cpp +++ b/3rdparty/spirv-tools/test/binary_destroy_test.cpp @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { namespace { using spvtest::ScopedContext; @@ -39,4 +40,5 @@ TEST_F(BinaryDestroySomething, Default) { spvBinaryDestroy(my_binary); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/binary_endianness_test.cpp b/3rdparty/spirv-tools/test/binary_endianness_test.cpp index 343c17e2e..3cd405d52 100644 --- a/3rdparty/spirv-tools/test/binary_endianness_test.cpp +++ b/3rdparty/spirv-tools/test/binary_endianness_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { TEST(BinaryEndianness, InvalidCode) { @@ -49,4 +50,5 @@ TEST(BinaryEndianness, Big) { ASSERT_EQ(SPV_ENDIANNESS_BIG, endian); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/binary_header_get_test.cpp b/3rdparty/spirv-tools/test/binary_header_get_test.cpp index d6efe5af0..e771f1a39 100644 --- a/3rdparty/spirv-tools/test/binary_header_get_test.cpp +++ b/3rdparty/spirv-tools/test/binary_header_get_test.cpp @@ -13,8 +13,9 @@ // limitations under the License. #include "source/spirv_constant.h" -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { class BinaryHeaderGet : public ::testing::Test { @@ -79,4 +80,5 @@ TEST_F(BinaryHeaderGet, TruncatedHeader) { } } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/binary_parse_test.cpp b/3rdparty/spirv-tools/test/binary_parse_test.cpp index bb1d6ffd1..7d9700158 100644 --- a/3rdparty/spirv-tools/test/binary_parse_test.cpp +++ b/3rdparty/spirv-tools/test/binary_parse_test.cpp @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include #include #include "gmock/gmock.h" -#include "latest_version_opencl_std_header.h" -#include "source/message.h" +#include "source/latest_version_opencl_std_header.h" #include "source/table.h" -#include "test_fixture.h" -#include "unit_spirv.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" // Returns true if two spv_parsed_operand_t values are equal. // To use this operator, this definition must appear in the same namespace @@ -33,9 +34,9 @@ static bool operator==(const spv_parsed_operand_t& a, a.number_bit_width == b.number_bit_width; } +namespace spvtools { namespace { -using ::libspirv::SetContextMessageConsumer; using ::spvtest::Concatenate; using ::spvtest::MakeInstruction; using ::spvtest::MakeVector; @@ -887,12 +888,13 @@ INSTANTIATE_TEST_CASE_P( "component 32"}, {"%2 = OpFunction %2 !31", "Invalid function control operand: 31 has invalid mask component 16"}, - {"OpLoopMerge %1 %2 !7", - "Invalid loop control operand: 7 has invalid mask component 4"}, + {"OpLoopMerge %1 %2 !1027", + "Invalid loop control operand: 1027 has invalid mask component 1024"}, {"%2 = OpImageFetch %1 %image %coord !511", "Invalid image operand: 511 has invalid mask component 256"}, {"OpSelectionMerge %1 !7", "Invalid selection control operand: 7 has invalid mask component 4"}, }), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/binary_strnlen_s_test.cpp b/3rdparty/spirv-tools/test/binary_strnlen_s_test.cpp index 2d2170b09..5f43bde67 100644 --- a/3rdparty/spirv-tools/test/binary_strnlen_s_test.cpp +++ b/3rdparty/spirv-tools/test/binary_strnlen_s_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { TEST(Strnlen, Samples) { @@ -27,4 +28,5 @@ TEST(Strnlen, Samples) { EXPECT_EQ(1u, spv_strnlen_s("a\0c", 5)); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/binary_to_text.literal_test.cpp b/3rdparty/spirv-tools/test/binary_to_text.literal_test.cpp index 9201f4a06..bcfb0f016 100644 --- a/3rdparty/spirv-tools/test/binary_to_text.literal_test.cpp +++ b/3rdparty/spirv-tools/test/binary_to_text.literal_test.cpp @@ -12,15 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" - -using ::testing::Eq; +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { +using ::testing::Eq; using RoundTripLiteralsTest = spvtest::TextToBinaryTestBase<::testing::TestWithParam>; @@ -69,4 +72,5 @@ INSTANTIATE_TEST_CASE_P( }),); // clang-format on -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/binary_to_text_test.cpp b/3rdparty/spirv-tools/test/binary_to_text_test.cpp index 1ec13db76..016041f49 100644 --- a/3rdparty/spirv-tools/test/binary_to_text_test.cpp +++ b/3rdparty/spirv-tools/test/binary_to_text_test.cpp @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" - #include +#include +#include +#include #include "gmock/gmock.h" - #include "source/spirv_constant.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::AutoText; using spvtest::ScopedContext; using spvtest::TextToBinaryTest; -using std::get; -using std::tuple; using ::testing::Combine; using ::testing::Eq; using ::testing::HasSubstr; @@ -226,13 +226,13 @@ OpExecutionMode %1 LocalSizeHint 100 200 300 } using RoundTripInstructionsTest = spvtest::TextToBinaryTestBase< - ::testing::TestWithParam>>; + ::testing::TestWithParam>>; TEST_P(RoundTripInstructionsTest, Sample) { - EXPECT_THAT(EncodeAndDecodeSuccessfully(get<1>(GetParam()), + EXPECT_THAT(EncodeAndDecodeSuccessfully(std::get<1>(GetParam()), SPV_BINARY_TO_TEXT_OPTION_NONE, - get<0>(GetParam())), - Eq(get<1>(GetParam()))); + std::get<0>(GetParam())), + Eq(std::get<1>(GetParam()))); } // clang-format off @@ -550,4 +550,5 @@ INSTANTIATE_TEST_CASE_P(GeneratorStrings, GeneratorStringTest, // TODO(dneto): Test new instructions and enums in SPIR-V 1.3 -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/bit_stream.cpp b/3rdparty/spirv-tools/test/bit_stream.cpp index d30a79f7f..f02faf3c6 100644 --- a/3rdparty/spirv-tools/test/bit_stream.cpp +++ b/3rdparty/spirv-tools/test/bit_stream.cpp @@ -12,38 +12,88 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include +#include #include #include "gmock/gmock.h" -#include "util/bit_stream.h" +#include "source/comp/bit_stream.h" +namespace spvtools { +namespace comp { namespace { -using spvutils::BitReaderInterface; -using spvutils::BitReaderWord64; -using spvutils::BitsetToStream; -using spvutils::BitsToStream; -using spvutils::BitWriterInterface; -using spvutils::BitWriterWord64; -using spvutils::BufferToStream; -using spvutils::DecodeZigZag; -using spvutils::EncodeZigZag; -using spvutils::GetLowerBits; -using spvutils::Log2U64; -using spvutils::NumBitsToNumWords; -using spvutils::PadToWord; -using spvutils::StreamToBits; -using spvutils::StreamToBitset; -using spvutils::StreamToBuffer; +// Converts |buffer| to a stream of '0' and '1'. +template +std::string BufferToStream(const std::vector& buffer) { + std::stringstream ss; + for (auto it = buffer.begin(); it != buffer.end(); ++it) { + std::string str = std::bitset(*it).to_string(); + // Strings generated by std::bitset::to_string are read right to left. + // Reversing to left to right. + std::reverse(str.begin(), str.end()); + ss << str; + } + return ss.str(); +} + +// Converts a left-to-right input string of '0' and '1' to a buffer of |T| +// words. +template +std::vector StreamToBuffer(std::string str) { + // The input string is left-to-right, the input argument of std::bitset needs + // to right-to-left. Instead of reversing tokens, reverse the entire string + // and iterate tokens from end to begin. + std::reverse(str.begin(), str.end()); + const int word_size = static_cast(sizeof(T) * 8); + const int str_length = static_cast(str.length()); + std::vector buffer; + buffer.reserve(NumBitsToNumWords(str.length())); + for (int index = str_length - word_size; index >= 0; index -= word_size) { + buffer.push_back(static_cast( + std::bitset(str, index, word_size).to_ullong())); + } + const size_t suffix_length = str.length() % word_size; + if (suffix_length != 0) { + buffer.push_back(static_cast( + std::bitset(str, 0, suffix_length).to_ullong())); + } + return buffer; +} + +// Adds '0' chars at the end of the string until the size is a multiple of N. +template +std::string PadToWord(std::string&& str) { + const size_t tail_length = str.size() % N; + if (tail_length != 0) str += std::string(N - tail_length, '0'); + return std::move(str); +} + +// Adds '0' chars at the end of the string until the size is a multiple of N. +template +std::string PadToWord(const std::string& str) { + return PadToWord(std::string(str)); +} + +// Converts a left-to-right stream of bits to std::bitset. +template +std::bitset StreamToBitset(std::string str) { + std::reverse(str.begin(), str.end()); + return std::bitset(str); +} + +// Converts a left-to-right stream of bits to uint64. +uint64_t StreamToBits(std::string str) { + std::reverse(str.begin(), str.end()); + return std::bitset<64>(str).to_ullong(); +} // A simple and inefficient implementatition of BitWriterInterface, // using std::stringstream. Intended for tests only. class BitWriterStringStream : public BitWriterInterface { public: - void WriteStream(const std::string& bits) override { ss_ << bits; } - void WriteBits(uint64_t bits, size_t num_bits) override { assert(num_bits <= 64); ss_ << BitsToStream(bits, num_bits); @@ -86,52 +136,11 @@ class BitReaderFromString : public BitReaderInterface { bool ReachedEnd() const override { return pos_ >= str_.length(); } - const std::string& GetStreamPadded64() const { return str_; } - private: std::string str_; size_t pos_; }; -TEST(Log2U16, Test) { - EXPECT_EQ(0u, Log2U64(0)); - EXPECT_EQ(0u, Log2U64(1)); - EXPECT_EQ(1u, Log2U64(2)); - EXPECT_EQ(1u, Log2U64(3)); - EXPECT_EQ(2u, Log2U64(4)); - EXPECT_EQ(2u, Log2U64(5)); - EXPECT_EQ(2u, Log2U64(6)); - EXPECT_EQ(2u, Log2U64(7)); - EXPECT_EQ(3u, Log2U64(8)); - EXPECT_EQ(3u, Log2U64(9)); - EXPECT_EQ(3u, Log2U64(10)); - EXPECT_EQ(3u, Log2U64(11)); - EXPECT_EQ(3u, Log2U64(12)); - EXPECT_EQ(3u, Log2U64(13)); - EXPECT_EQ(3u, Log2U64(14)); - EXPECT_EQ(3u, Log2U64(15)); - EXPECT_EQ(4u, Log2U64(16)); - EXPECT_EQ(4u, Log2U64(17)); - EXPECT_EQ(5u, Log2U64(35)); - EXPECT_EQ(6u, Log2U64(72)); - EXPECT_EQ(7u, Log2U64(255)); - EXPECT_EQ(8u, Log2U64(256)); - EXPECT_EQ(15u, Log2U64(65535)); - EXPECT_EQ(16u, Log2U64(65536)); - EXPECT_EQ(19u, Log2U64(0xFFFFF)); - EXPECT_EQ(23u, Log2U64(0xFFFFFF)); - EXPECT_EQ(27u, Log2U64(0xFFFFFFF)); - EXPECT_EQ(31u, Log2U64(0xFFFFFFFF)); - EXPECT_EQ(35u, Log2U64(0xFFFFFFFFF)); - EXPECT_EQ(39u, Log2U64(0xFFFFFFFFFF)); - EXPECT_EQ(43u, Log2U64(0xFFFFFFFFFFF)); - EXPECT_EQ(47u, Log2U64(0xFFFFFFFFFFFF)); - EXPECT_EQ(51u, Log2U64(0xFFFFFFFFFFFFF)); - EXPECT_EQ(55u, Log2U64(0xFFFFFFFFFFFFFF)); - EXPECT_EQ(59u, Log2U64(0xFFFFFFFFFFFFFFF)); - EXPECT_EQ(63u, Log2U64(0xFFFFFFFFFFFFFFFF)); -} - TEST(NumBitsToNumWords, Word8) { EXPECT_EQ(0u, NumBitsToNumWords<8>(0)); EXPECT_EQ(1u, NumBitsToNumWords<8>(1)); @@ -154,34 +163,6 @@ TEST(NumBitsToNumWords, Word64) { EXPECT_EQ(3u, NumBitsToNumWords<64>(129)); } -TEST(ZigZagCoding, Encode) { - EXPECT_EQ(0u, EncodeZigZag(0)); - EXPECT_EQ(1u, EncodeZigZag(-1)); - EXPECT_EQ(2u, EncodeZigZag(1)); - EXPECT_EQ(3u, EncodeZigZag(-2)); - EXPECT_EQ(4u, EncodeZigZag(2)); - EXPECT_EQ(5u, EncodeZigZag(-3)); - EXPECT_EQ(6u, EncodeZigZag(3)); - EXPECT_EQ(std::numeric_limits::max() - 1, - EncodeZigZag(std::numeric_limits::max())); - EXPECT_EQ(std::numeric_limits::max(), - EncodeZigZag(std::numeric_limits::min())); -} - -TEST(ZigZagCoding, Decode) { - EXPECT_EQ(0, DecodeZigZag(0)); - EXPECT_EQ(-1, DecodeZigZag(1)); - EXPECT_EQ(1, DecodeZigZag(2)); - EXPECT_EQ(-2, DecodeZigZag(3)); - EXPECT_EQ(2, DecodeZigZag(4)); - EXPECT_EQ(-3, DecodeZigZag(5)); - EXPECT_EQ(3, DecodeZigZag(6)); - EXPECT_EQ(std::numeric_limits::min(), - DecodeZigZag(std::numeric_limits::max())); - EXPECT_EQ(std::numeric_limits::max(), - DecodeZigZag(std::numeric_limits::max() - 1)); -} - TEST(ZigZagCoding, Encode0) { EXPECT_EQ(0u, EncodeZigZag(0, 0)); EXPECT_EQ(1u, EncodeZigZag(-1, 0)); @@ -204,18 +185,6 @@ TEST(ZigZagCoding, Decode0) { DecodeZigZag(std::numeric_limits::max() - 1, 0)); } -TEST(ZigZagCoding, Decode0SameAsNormalZigZag) { - for (int32_t i = -10000; i < 10000; i += 123) { - ASSERT_EQ(DecodeZigZag(i), DecodeZigZag(i, 0)); - } -} - -TEST(ZigZagCoding, Encode0SameAsNormalZigZag) { - for (uint32_t i = 0; i < 10000; i += 123) { - ASSERT_EQ(EncodeZigZag(i), EncodeZigZag(i, 0)); - } -} - TEST(ZigZagCoding, Encode1) { EXPECT_EQ(0u, EncodeZigZag(0, 1)); EXPECT_EQ(1u, EncodeZigZag(1, 1)); @@ -441,30 +410,6 @@ TEST(BitWriterStringStream, Empty) { EXPECT_EQ("", writer.GetStreamRaw()); } -TEST(BitWriterStringStream, WriteStream) { - BitWriterStringStream writer; - const std::string bits1 = "1011111111111111111"; - writer.WriteStream(bits1); - EXPECT_EQ(19u, writer.GetNumBits()); - EXPECT_EQ(3u, writer.GetDataSizeBytes()); - EXPECT_EQ(bits1, writer.GetStreamRaw()); - - const std::string bits2 = "10100001010101010000111111111111111111111111111"; - writer.WriteStream(bits2); - EXPECT_EQ(66u, writer.GetNumBits()); - EXPECT_EQ(9u, writer.GetDataSizeBytes()); - EXPECT_EQ(bits1 + bits2, writer.GetStreamRaw()); -} - -TEST(BitWriterStringStream, WriteBitSet) { - BitWriterStringStream writer; - const std::string bits1 = "10101"; - writer.WriteBitset(StreamToBitset<16>(bits1)); - EXPECT_EQ(16u, writer.GetNumBits()); - EXPECT_EQ(2u, writer.GetDataSizeBytes()); - EXPECT_EQ(PadToWord<16>(bits1), writer.GetStreamRaw()); -} - TEST(BitWriterStringStream, WriteBits) { BitWriterStringStream writer; const uint64_t bits1 = 0x1 | 0x2 | 0x10; @@ -495,20 +440,19 @@ TEST(BitWriterStringStream, WriteMultiple) { BitWriterStringStream writer; std::string expected_result; - const std::string bits1 = "101001111111001100010000001110001111111100"; - writer.WriteStream(bits1); - const std::string bits2 = "10100011000010010101"; - writer.WriteBitset(StreamToBitset<20>(bits2)); + const uint64_t b2_val = 0x4 | 0x2 | 0x40; + const std::string bits2 = BitsToStream(b2_val, 8); + writer.WriteBits(b2_val, 8); const uint64_t val = 0x1 | 0x2 | 0x10; const std::string bits3 = BitsToStream(val, 8); writer.WriteBits(val, 8); - const std::string expected = bits1 + bits2 + bits3; + const std::string expected = bits2 + bits3; EXPECT_EQ(expected.length(), writer.GetNumBits()); - EXPECT_EQ(9u, writer.GetDataSizeBytes()); + EXPECT_EQ(2u, writer.GetDataSizeBytes()); EXPECT_EQ(expected, writer.GetStreamRaw()); EXPECT_EQ(PadToWord<8>(expected), BufferToStream(writer.GetDataCopy())); @@ -518,46 +462,6 @@ TEST(BitWriterWord64, Empty) { BitWriterWord64 writer; EXPECT_EQ(0u, writer.GetNumBits()); EXPECT_EQ(0u, writer.GetDataSizeBytes()); - EXPECT_EQ("", writer.GetStreamPadded64()); -} - -TEST(BitWriterWord64, WriteStream) { - BitWriterWord64 writer; - std::string expected; - - { - const std::string bits = "101"; - expected += bits; - writer.WriteStream(bits); - EXPECT_EQ(expected.length(), writer.GetNumBits()); - EXPECT_EQ(1u, writer.GetDataSizeBytes()); - EXPECT_EQ(PadToWord<64>(expected), writer.GetStreamPadded64()); - } - - { - const std::string bits = "10000111111111110000000"; - expected += bits; - writer.WriteStream(bits); - EXPECT_EQ(expected.length(), writer.GetNumBits()); - EXPECT_EQ(PadToWord<64>(expected), writer.GetStreamPadded64()); - } - - { - const std::string bits = "101001111111111100000111111111111100"; - expected += bits; - writer.WriteStream(bits); - EXPECT_EQ(expected.length(), writer.GetNumBits()); - EXPECT_EQ(PadToWord<64>(expected), writer.GetStreamPadded64()); - } -} - -TEST(BitWriterWord64, WriteBitset) { - BitWriterWord64 writer; - const std::string bits1 = "10101"; - writer.WriteBitset(StreamToBitset<16>(bits1), 12); - EXPECT_EQ(12u, writer.GetNumBits()); - EXPECT_EQ(2u, writer.GetDataSizeBytes()); - EXPECT_EQ(PadToWord<64>(bits1), writer.GetStreamPadded64()); } TEST(BitWriterWord64, WriteBits) { @@ -568,7 +472,6 @@ TEST(BitWriterWord64, WriteBits) { writer.WriteBits(bits1, 5); EXPECT_EQ(15u, writer.GetNumBits()); EXPECT_EQ(2u, writer.GetDataSizeBytes()); - EXPECT_EQ(PadToWord<64>("110011100111001"), writer.GetStreamPadded64()); } TEST(BitWriterWord64, WriteZeroBits) { @@ -578,18 +481,11 @@ TEST(BitWriterWord64, WriteZeroBits) { EXPECT_EQ(0u, writer.GetNumBits()); writer.WriteBits(1, 1); writer.WriteBits(0, 0); - EXPECT_EQ(PadToWord<64>("1"), writer.GetStreamPadded64()); writer.WriteBits(0, 63); EXPECT_EQ(64u, writer.GetNumBits()); writer.WriteBits(0, 0); writer.WriteBits(7, 3); writer.WriteBits(0, 0); - EXPECT_EQ( - PadToWord<64>( - "1" - "000000000000000000000000000000000000000000000000000000000000000" - "111"), - writer.GetStreamPadded64()); } TEST(BitWriterWord64, ComparisonTestWriteLotsOfBits) { @@ -601,42 +497,6 @@ TEST(BitWriterWord64, ComparisonTestWriteLotsOfBits) { writer2.WriteBits(i, 16); ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits()); } - - EXPECT_EQ(PadToWord<64>(writer1.GetStreamRaw()), writer2.GetStreamPadded64()); -} - -TEST(BitWriterWord64, ComparisonTestWriteLotsOfStreams) { - BitWriterStringStream writer1; - BitWriterWord64 writer2(16384); - - for (int i = 0; i < 1000; ++i) { - std::string bits = "1111100000"; - if (i % 2) bits += "101010"; - if (i % 3) bits += "1110100"; - if (i % 5) bits += "1110100111111111111"; - writer1.WriteStream(bits); - writer2.WriteStream(bits); - ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits()); - } - - EXPECT_EQ(PadToWord<64>(writer1.GetStreamRaw()), writer2.GetStreamPadded64()); -} - -TEST(BitWriterWord64, ComparisonTestWriteLotsOfBitsets) { - BitWriterStringStream writer1; - BitWriterWord64 writer2(16384); - - for (uint64_t i = 0; i < 65000; i += 25) { - std::bitset<16> bits1(i); - std::bitset<24> bits2(i); - writer1.WriteBitset(bits1); - writer1.WriteBitset(bits2); - writer2.WriteBitset(bits1); - writer2.WriteBitset(bits2); - ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits()); - } - - EXPECT_EQ(PadToWord<64>(writer1.GetStreamRaw()), writer2.GetStreamPadded64()); } TEST(GetLowerBits, Test) { @@ -674,7 +534,6 @@ TEST(BitReaderFromString, FromU8) { "10111011"; BitReaderFromString reader(buffer); - EXPECT_EQ(PadToWord<64>(total_stream), reader.GetStreamPadded64()); uint64_t bits = 0; EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); @@ -703,7 +562,6 @@ TEST(BitReaderFromString, FromU64) { "1011101110111011101110111011101110111011101110111011101110111011"; BitReaderFromString reader(buffer); - EXPECT_EQ(total_stream, reader.GetStreamPadded64()); uint64_t bits = 0; size_t pos = 0; @@ -741,40 +599,6 @@ TEST(BitReaderWord64, ReadBitsSingleByte) { EXPECT_TRUE(reader.ReachedEnd()); } -TEST(BitReaderWord64, ReadBitsetSingleByte) { - BitReaderWord64 reader(std::vector({uint8_t(0xCC)})); - std::bitset<4> bits; - EXPECT_EQ(2u, reader.ReadBitset(&bits, 2)); - EXPECT_EQ(0u, bits.to_ullong()); - EXPECT_EQ(2u, reader.ReadBitset(&bits, 2)); - EXPECT_EQ(3u, bits.to_ullong()); - EXPECT_FALSE(reader.OnlyZeroesLeft()); - EXPECT_EQ(4u, reader.ReadBitset(&bits, 4)); - EXPECT_EQ(12u, bits.to_ullong()); - EXPECT_TRUE(reader.OnlyZeroesLeft()); -} - -TEST(BitReaderWord64, ReadStreamSingleByte) { - BitReaderWord64 reader(std::vector({uint8_t(0xAA)})); - EXPECT_EQ("", reader.ReadStream(0)); - EXPECT_EQ("0", reader.ReadStream(1)); - EXPECT_EQ("101", reader.ReadStream(3)); - EXPECT_EQ("01010000", reader.ReadStream(8)); - EXPECT_TRUE(reader.OnlyZeroesLeft()); - EXPECT_EQ("0000000000000000000000000000000000000000000000000000", - reader.ReadStream(64)); - EXPECT_TRUE(reader.ReachedEnd()); -} - -TEST(BitReaderWord64, ReadStreamEmpty) { - std::vector buffer; - BitReaderWord64 reader(std::move(buffer)); - EXPECT_TRUE(reader.OnlyZeroesLeft()); - EXPECT_TRUE(reader.ReachedEnd()); - EXPECT_EQ("", reader.ReadStream(10)); - EXPECT_TRUE(reader.ReachedEnd()); -} - TEST(BitReaderWord64, ReadBitsTwoWords) { std::vector buffer = {0x0000000000000001, 0x0000000000FFFFFF}; @@ -947,37 +771,6 @@ TEST(VariableWidthWrite, Write0U) { "000" "000", writer.GetStreamRaw()); - writer.WriteVariableWidthU8(0, 2); - EXPECT_EQ( - "000" - "000" - "000" - "000", - writer.GetStreamRaw()); -} - -TEST(VariableWidthWrite, Write0S) { - BitWriterStringStream writer; - writer.WriteVariableWidthS64(0, 2, 0); - EXPECT_EQ("000", writer.GetStreamRaw()); - writer.WriteVariableWidthS32(0, 2, 0); - EXPECT_EQ( - "000" - "000", - writer.GetStreamRaw()); - writer.WriteVariableWidthS16(0, 2, 0); - EXPECT_EQ( - "000" - "000" - "000", - writer.GetStreamRaw()); - writer.WriteVariableWidthS8(0, 2, 0); - EXPECT_EQ( - "000" - "000" - "000" - "000", - writer.GetStreamRaw()); } TEST(VariableWidthWrite, WriteSmallUnsigned) { @@ -995,13 +788,6 @@ TEST(VariableWidthWrite, WriteSmallUnsigned) { "010" "110", writer.GetStreamRaw()); - writer.WriteVariableWidthU8(4, 2); - EXPECT_EQ( - "100" - "010" - "110" - "001100", - writer.GetStreamRaw()); } TEST(VariableWidthWrite, WriteSmallSigned) { @@ -1013,19 +799,6 @@ TEST(VariableWidthWrite, WriteSmallSigned) { "010" "100", writer.GetStreamRaw()); - writer.WriteVariableWidthS16(3, 2, 0); - EXPECT_EQ( - "010" - "100" - "011100", - writer.GetStreamRaw()); - writer.WriteVariableWidthS8(-4, 2, 0); - EXPECT_EQ( - "010" - "100" - "011100" - "111100", - writer.GetStreamRaw()); } TEST(VariableWidthWrite, U64Val127ChunkLength7) { @@ -1057,16 +830,6 @@ TEST(VariableWidthWrite, U16Val2ChunkLength4) { writer.GetStreamRaw()); } -TEST(VariableWidthWrite, U8Val128ChunkLength7) { - BitWriterStringStream writer; - writer.WriteVariableWidthU8(128, 7); - EXPECT_EQ( - "0000000" - "1" - "1", - writer.GetStreamRaw()); -} - TEST(VariableWidthWrite, U64ValAAAAChunkLength2) { BitWriterStringStream writer; writer.WriteVariableWidthU64(0xAAAA, 2); @@ -1090,16 +853,6 @@ TEST(VariableWidthWrite, U64ValAAAAChunkLength2) { writer.GetStreamRaw()); } -TEST(VariableWidthWrite, S8ValM128ChunkLength7) { - BitWriterStringStream writer; - writer.WriteVariableWidthS8(-128, 7, 0); - EXPECT_EQ( - "1111111" - "1" - "1", - writer.GetStreamRaw()); -} - TEST(VariableWidthRead, U64Val127ChunkLength7) { BitReaderFromString reader( "1111111" @@ -1129,16 +882,6 @@ TEST(VariableWidthRead, U16Val2ChunkLength4) { EXPECT_EQ(2u, val); } -TEST(VariableWidthRead, U8Val128ChunkLength7) { - BitReaderFromString reader( - "0000000" - "1" - "1"); - uint8_t val = 0; - ASSERT_TRUE(reader.ReadVariableWidthU8(&val, 7)); - EXPECT_EQ(128u, val); -} - TEST(VariableWidthRead, U64ValAAAAChunkLength2) { BitReaderFromString reader( "01" @@ -1162,16 +905,6 @@ TEST(VariableWidthRead, U64ValAAAAChunkLength2) { EXPECT_EQ(0xAAAAu, val); } -TEST(VariableWidthRead, S8ValM128ChunkLength7) { - BitReaderFromString reader( - "1111111" - "1" - "1"); - int8_t val = 0; - ASSERT_TRUE(reader.ReadVariableWidthS8(&val, 7, 0)); - EXPECT_EQ(-128, val); -} - TEST(VariableWidthRead, FailTooShort) { BitReaderFromString reader("00000001100000"); uint64_t val = 0; @@ -1228,24 +961,6 @@ TEST(VariableWidthWriteRead, SingleWriteReadU32) { } } -TEST(VariableWidthWriteRead, SingleWriteReadS32) { - for (int32_t i = 0; i < 100000; i += 123) { - const int32_t val = i * (i % 2 ? -i : i); - const size_t chunk_length = i % 16 + 1; - const size_t zigzag_exponent = i % 11; - - BitWriterWord64 writer; - writer.WriteVariableWidthS32(val, chunk_length, zigzag_exponent); - - BitReaderWord64 reader(writer.GetDataCopy()); - int32_t read_val = 0; - ASSERT_TRUE( - reader.ReadVariableWidthS32(&read_val, chunk_length, zigzag_exponent)); - - ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; - } -} - TEST(VariableWidthWriteRead, SingleWriteReadU16) { for (int i = 0; i < 65536; i += 123) { const uint16_t val = static_cast(i); @@ -1262,58 +977,6 @@ TEST(VariableWidthWriteRead, SingleWriteReadU16) { } } -TEST(VariableWidthWriteRead, SingleWriteReadS16) { - for (int i = -32768; i < 32768; i += 123) { - const int16_t val = static_cast(i); - const size_t chunk_length = std::abs(i) % 10 + 1; - const size_t zigzag_exponent = std::abs(i) % 7; - - BitWriterWord64 writer; - writer.WriteVariableWidthS16(val, chunk_length, zigzag_exponent); - - BitReaderWord64 reader(writer.GetDataCopy()); - int16_t read_val = 0; - ASSERT_TRUE( - reader.ReadVariableWidthS16(&read_val, chunk_length, zigzag_exponent)); - - ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; - } -} - -TEST(VariableWidthWriteRead, SingleWriteReadU8) { - for (int i = 0; i < 256; ++i) { - const uint8_t val = static_cast(i); - const size_t chunk_length = val % 5 + 1; - - BitWriterWord64 writer; - writer.WriteVariableWidthU8(val, chunk_length); - - BitReaderWord64 reader(writer.GetDataCopy()); - uint8_t read_val = 0; - ASSERT_TRUE(reader.ReadVariableWidthU8(&read_val, chunk_length)); - - ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; - } -} - -TEST(VariableWidthWriteRead, SingleWriteReadS8) { - for (int i = -128; i < 128; ++i) { - const int8_t val = static_cast(i); - const size_t chunk_length = std::abs(i) % 5 + 1; - const size_t zigzag_exponent = std::abs(i) % 3; - - BitWriterWord64 writer; - writer.WriteVariableWidthS8(val, chunk_length, zigzag_exponent); - - BitReaderWord64 reader(writer.GetDataCopy()); - int8_t read_val = 0; - ASSERT_TRUE( - reader.ReadVariableWidthS8(&read_val, chunk_length, zigzag_exponent)); - - ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; - } -} - TEST(VariableWidthWriteRead, SmallNumbersChunkLength4) { const std::vector expected_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; @@ -1357,68 +1020,6 @@ TEST(VariableWidthWriteRead, VariedNumbersChunkLength8) { EXPECT_EQ(expected_values, actual_values); } -TEST(FixedWidthWrite, Val0Max3) { - BitWriterStringStream writer; - writer.WriteFixedWidth(0, 3); - EXPECT_EQ("00", writer.GetStreamRaw()); -} - -TEST(FixedWidthWrite, Val0Max5) { - BitWriterStringStream writer; - writer.WriteFixedWidth(0, 5); - EXPECT_EQ("000", writer.GetStreamRaw()); -} - -TEST(FixedWidthWrite, Val0Max255) { - BitWriterStringStream writer; - writer.WriteFixedWidth(0, 255); - EXPECT_EQ("00000000", writer.GetStreamRaw()); -} - -TEST(FixedWidthWrite, Val3Max8) { - BitWriterStringStream writer; - writer.WriteFixedWidth(3, 8); - EXPECT_EQ("1100", writer.GetStreamRaw()); -} - -TEST(FixedWidthWrite, Val15Max127) { - BitWriterStringStream writer; - writer.WriteFixedWidth(15, 127); - EXPECT_EQ("1111000", writer.GetStreamRaw()); -} - -TEST(FixedWidthRead, Val0Max3) { - BitReaderFromString reader("0011111"); - uint64_t val = 0; - ASSERT_TRUE(reader.ReadFixedWidth(&val, 3)); - EXPECT_EQ(0u, val); -} - -TEST(FixedWidthRead, Val0Max5) { - BitReaderFromString reader("0001010101"); - uint64_t val = 0; - ASSERT_TRUE(reader.ReadFixedWidth(&val, 5)); - EXPECT_EQ(0u, val); -} - -TEST(FixedWidthRead, Val3Max8) { - BitReaderFromString reader("11001010101"); - uint64_t val = 0; - ASSERT_TRUE(reader.ReadFixedWidth(&val, 8)); - EXPECT_EQ(3u, val); -} - -TEST(FixedWidthRead, Val15Max127) { - BitReaderFromString reader("111100010101"); - uint64_t val = 0; - ASSERT_TRUE(reader.ReadFixedWidth(&val, 127)); - EXPECT_EQ(15u, val); -} - -TEST(FixedWidthRead, Fail) { - BitReaderFromString reader("111100"); - uint64_t val = 0; - ASSERT_FALSE(reader.ReadFixedWidth(&val, 127)); -} - -} // anonymous namespace +} // namespace +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/c_interface_test.cpp b/3rdparty/spirv-tools/test/c_interface_test.cpp index 9260549d6..1b735be5d 100644 --- a/3rdparty/spirv-tools/test/c_interface_test.cpp +++ b/3rdparty/spirv-tools/test/c_interface_test.cpp @@ -12,24 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - -#include "message.h" +#include "gtest/gtest.h" +#include "source/table.h" #include "spirv-tools/libspirv.h" -#include "table.h" +namespace spvtools { namespace { -using namespace spvtools; - // TODO(antiagainst): Use public C API for setting the consumer once exists. #ifndef SPIRV_TOOLS_SHAREDLIB -void SetContextMessageConsumer(spv_context context, - spvtools::MessageConsumer consumer) { - libspirv::SetContextMessageConsumer(context, consumer); +void SetContextMessageConsumer(spv_context context, MessageConsumer consumer) { + spvtools::SetContextMessageConsumer(context, consumer); } #else -void SetContextMessageConsumer(spv_context, spvtools::MessageConsumer) {} +void SetContextMessageConsumer(spv_context, MessageConsumer) {} #endif // The default consumer is a null std::function. @@ -194,8 +190,10 @@ TEST(CInterface, SpecifyConsumerNullDiagnosticForValidating) { // TODO(antiagainst): what validation reports is not a word offset here. // It is inconsistent with diassembler. Should be fixed. EXPECT_EQ(1u, position.index); - EXPECT_STREQ("Nop cannot appear before the memory model instruction", - message); + EXPECT_STREQ( + "Nop cannot appear before the memory model instruction\n" + " OpNop\n", + message); }); spv_binary binary = nullptr; @@ -287,12 +285,15 @@ TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForValidating) { EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, &diagnostic)); EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. - EXPECT_STREQ("Nop cannot appear before the memory model instruction", - diagnostic->error); + EXPECT_STREQ( + "Nop cannot appear before the memory model instruction\n" + " OpNop\n", + diagnostic->error); spvDiagnosticDestroy(diagnostic); spvBinaryDestroy(binary); spvContextDestroy(context); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/comment_test.cpp b/3rdparty/spirv-tools/test/comment_test.cpp index f60b7918b..f46b72ac5 100644 --- a/3rdparty/spirv-tools/test/comment_test.cpp +++ b/3rdparty/spirv-tools/test/comment_test.cpp @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "test_fixture.h" -#include "unit_spirv.h" +#include +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { namespace { using spvtest::Concatenate; @@ -43,4 +46,5 @@ TEST_F(TextToBinaryTest, Whitespace) { MakeVector("GLSL.std.450"))}))); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/comp/markv_codec_test.cpp b/3rdparty/spirv-tools/test/comp/markv_codec_test.cpp index a313d6ee3..76918f747 100644 --- a/3rdparty/spirv-tools/test/comp/markv_codec_test.cpp +++ b/3rdparty/spirv-tools/test/comp/markv_codec_test.cpp @@ -17,18 +17,19 @@ #include #include #include +#include #include "gmock/gmock.h" #include "source/comp/markv.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" #include "tools/comp/markv_model_factory.h" -#include "unit_spirv.h" +namespace spvtools { +namespace comp { namespace { -using libspirv::SetContextMessageConsumer; using spvtest::ScopedContext; -using spvtools::MarkvModelType; using MarkvTest = ::testing::TestWithParam; void DiagnosticsMessageHandler(spv_message_level_t level, const char*, @@ -90,9 +91,8 @@ void Disassemble(const std::vector& words, std::string* out_text, void TestEncodeDecode(MarkvModelType model_type, const std::string& original_text) { ScopedContext ctx(SPV_ENV_UNIVERSAL_1_2); - std::unique_ptr model = - spvtools::CreateMarkvModel(model_type); - spvtools::MarkvCodecOptions options; + std::unique_ptr model = CreateMarkvModel(model_type); + MarkvCodecOptions options; std::vector expected_binary; Compile(original_text, &expected_binary); @@ -112,18 +112,17 @@ void TestEncodeDecode(MarkvModelType model_type, [&encoder_comments](const std::string& str) { encoder_comments << str; }; std::vector markv; - ASSERT_EQ(SPV_SUCCESS, spvtools::SpirvToMarkv( - ctx.context, binary_to_encode, options, *model, - DiagnosticsMessageHandler, output_to_string_stream, - spvtools::MarkvDebugConsumer(), &markv)); + ASSERT_EQ(SPV_SUCCESS, + SpirvToMarkv(ctx.context, binary_to_encode, options, *model, + DiagnosticsMessageHandler, output_to_string_stream, + MarkvDebugConsumer(), &markv)); ASSERT_FALSE(markv.empty()); std::vector decoded_binary; ASSERT_EQ(SPV_SUCCESS, - spvtools::MarkvToSpirv( - ctx.context, markv, options, *model, DiagnosticsMessageHandler, - spvtools::MarkvLogConsumer(), spvtools::MarkvDebugConsumer(), - &decoded_binary)); + MarkvToSpirv(ctx.context, markv, options, *model, + DiagnosticsMessageHandler, MarkvLogConsumer(), + MarkvDebugConsumer(), &decoded_binary)); ASSERT_FALSE(decoded_binary.empty()); EXPECT_EQ(expected_binary, decoded_binary) << encoder_comments.str(); @@ -820,9 +819,11 @@ OpFunctionEnd INSTANTIATE_TEST_CASE_P(AllMarkvModels, MarkvTest, ::testing::ValuesIn(std::vector{ - spvtools::kMarkvModelShaderLite, - spvtools::kMarkvModelShaderMid, - spvtools::kMarkvModelShaderMax, + kMarkvModelShaderLite, + kMarkvModelShaderMid, + kMarkvModelShaderMax, }), ); } // namespace +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/cpp_interface_test.cpp b/3rdparty/spirv-tools/test/cpp_interface_test.cpp index 2bab430b2..bcc2cd6c7 100644 --- a/3rdparty/spirv-tools/test/cpp_interface_test.cpp +++ b/3rdparty/spirv-tools/test/cpp_interface_test.cpp @@ -12,18 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include +#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" #include "spirv-tools/optimizer.hpp" #include "spirv/1.1/spirv.h" +namespace spvtools { namespace { -using namespace spvtools; using ::testing::ContainerEq; using ::testing::HasSubstr; +// Return a string that contains the minimum instructions needed to form +// a valid module. Other instructions can be appended to this string. +std::string Header() { + return R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; +} + TEST(CppInterface, SuccessfulRoundTrip) { const std::string input_text = "%2 = OpSizeOf %1 %3\n"; SpirvTools t(SPV_ENV_UNIVERSAL_1_1); @@ -42,7 +54,7 @@ TEST(CppInterface, SuccessfulRoundTrip) { EXPECT_EQ(0u, position.line); EXPECT_EQ(0u, position.column); EXPECT_EQ(1u, position.index); - EXPECT_STREQ("ID 1 has not been defined", message); + EXPECT_STREQ("ID 1 has not been defined\n %2 = OpSizeOf %1 %3\n", message); }); EXPECT_FALSE(t.Validate(binary)); @@ -87,28 +99,6 @@ TEST(CppInterface, AssembleOverloads) { } } -TEST(CppInterface, AssembleWithWrongTargetEnv) { - const std::string input_text = "%r = OpSizeOf %type %pointer"; - SpirvTools t(SPV_ENV_UNIVERSAL_1_0); - int invocation_count = 0; - t.SetMessageConsumer( - [&invocation_count](spv_message_level_t level, const char* source, - const spv_position_t& position, const char* message) { - ++invocation_count; - EXPECT_EQ(SPV_MSG_ERROR, level); - EXPECT_STREQ("input", source); - EXPECT_EQ(0u, position.line); - EXPECT_EQ(5u, position.column); - EXPECT_EQ(5u, position.index); - EXPECT_STREQ("Invalid Opcode name 'OpSizeOf'", message); - }); - - std::vector binary = {42, 42}; - EXPECT_FALSE(t.Assemble(input_text, &binary)); - EXPECT_THAT(binary, ContainerEq(std::vector{42, 42})); - EXPECT_EQ(1, invocation_count); -} - TEST(CppInterface, DisassembleEmptyModule) { std::string text(10, 'x'); SpirvTools t(SPV_ENV_UNIVERSAL_1_1); @@ -148,36 +138,7 @@ TEST(CppInterface, DisassembleOverloads) { } } -TEST(CppInterface, DisassembleWithWrongTargetEnv) { - const std::string input_text = "%r = OpSizeOf %type %pointer"; - SpirvTools t11(SPV_ENV_UNIVERSAL_1_1); - SpirvTools t10(SPV_ENV_UNIVERSAL_1_0); - int invocation_count = 0; - t10.SetMessageConsumer( - [&invocation_count](spv_message_level_t level, const char* source, - const spv_position_t& position, const char* message) { - ++invocation_count; - EXPECT_EQ(SPV_MSG_ERROR, level); - EXPECT_STREQ("input", source); - EXPECT_EQ(0u, position.line); - EXPECT_EQ(0u, position.column); - EXPECT_EQ(5u, position.index); - EXPECT_STREQ("Invalid opcode: 321", message); - }); - - std::vector binary; - EXPECT_TRUE(t11.Assemble(input_text, &binary)); - - std::string output_text(10, 'x'); - EXPECT_FALSE(t10.Disassemble(binary, &output_text)); - EXPECT_EQ("xxxxxxxxxx", output_text); // The original string is unmodified. -} - TEST(CppInterface, SuccessfulValidation) { - const std::string input_text = R"( - OpCapability Shader - OpCapability Linkage - OpMemoryModel Logical GLSL450)"; SpirvTools t(SPV_ENV_UNIVERSAL_1_1); int invocation_count = 0; t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*, @@ -186,19 +147,15 @@ TEST(CppInterface, SuccessfulValidation) { }); std::vector binary; - EXPECT_TRUE(t.Assemble(input_text, &binary)); + EXPECT_TRUE(t.Assemble(Header(), &binary)); EXPECT_TRUE(t.Validate(binary)); EXPECT_EQ(0, invocation_count); } TEST(CppInterface, ValidateOverloads) { - const std::string input_text = R"( - OpCapability Shader - OpCapability Linkage - OpMemoryModel Logical GLSL450)"; SpirvTools t(SPV_ENV_UNIVERSAL_1_1); std::vector binary; - EXPECT_TRUE(t.Assemble(input_text, &binary)); + EXPECT_TRUE(t.Assemble(Header(), &binary)); { EXPECT_TRUE(t.Validate(binary)); } { EXPECT_TRUE(t.Validate(binary.data(), binary.size())); } @@ -226,11 +183,9 @@ TEST(CppInterface, ValidateEmptyModule) { // with the given number of members. std::string MakeModuleHavingStruct(int num_members) { std::stringstream os; - os << R"(OpCapability Shader - OpCapability Linkage - OpMemoryModel Logical GLSL450 - %1 = OpTypeInt 32 0 - %2 = OpTypeStruct)"; + os << Header(); + os << R"(%1 = OpTypeInt 32 0 + %2 = OpTypeStruct)"; for (int i = 0; i < num_members; i++) os << " %1"; return os.str(); } @@ -239,7 +194,7 @@ TEST(CppInterface, ValidateWithOptionsPass) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); std::vector binary; EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary)); - const spvtools::ValidatorOptions opts; + const ValidatorOptions opts; EXPECT_TRUE(t.Validate(binary.data(), binary.size(), opts)); } @@ -248,7 +203,7 @@ TEST(CppInterface, ValidateWithOptionsFail) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); std::vector binary; EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary)); - spvtools::ValidatorOptions opts; + ValidatorOptions opts; opts.SetUniversalLimit(spv_validator_limit_max_struct_members, 9); std::stringstream os; t.SetMessageConsumer([&os](spv_message_level_t, const char*, @@ -264,8 +219,8 @@ TEST(CppInterface, ValidateWithOptionsFail) { // Checks that after running the given optimizer |opt| on the given |original| // source code, we can get the given |optimized| source code. -void CheckOptimization(const char* original, const char* optimized, - const Optimizer& opt) { +void CheckOptimization(const std::string& original, + const std::string& optimized, const Optimizer& opt) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); std::vector original_binary; ASSERT_TRUE(t.Assemble(original, &original_binary)); @@ -286,29 +241,31 @@ TEST(CppInterface, OptimizeEmptyModule) { Optimizer o(SPV_ENV_UNIVERSAL_1_1); o.RegisterPass(CreateStripDebugInfoPass()); - EXPECT_TRUE(o.Run(binary.data(), binary.size(), &binary)); + + // Fails to validate. + EXPECT_FALSE(o.Run(binary.data(), binary.size(), &binary)); } TEST(CppInterface, OptimizeModifiedModule) { Optimizer o(SPV_ENV_UNIVERSAL_1_1); o.RegisterPass(CreateStripDebugInfoPass()); - CheckOptimization("OpSource GLSL 450", "", o); + CheckOptimization(Header() + "OpSource GLSL 450", Header(), o); } TEST(CppInterface, OptimizeMulitplePasses) { - const char* original_text = - "OpSource GLSL 450 " - "OpDecorate %true SpecId 1 " - "%bool = OpTypeBool " - "%true = OpSpecConstantTrue %bool"; + std::string original_text = Header() + + "OpSource GLSL 450 " + "OpDecorate %true SpecId 1 " + "%bool = OpTypeBool " + "%true = OpSpecConstantTrue %bool"; Optimizer o(SPV_ENV_UNIVERSAL_1_1); o.RegisterPass(CreateStripDebugInfoPass()) .RegisterPass(CreateFreezeSpecConstantValuePass()); - const char* expected_text = - "%bool = OpTypeBool\n" - "%true = OpConstantTrue %bool\n"; + std::string expected_text = Header() + + "%bool = OpTypeBool\n" + "%true = OpConstantTrue %bool\n"; CheckOptimization(original_text, expected_text, o); } @@ -323,7 +280,7 @@ TEST(CppInterface, OptimizeReassignPassToken) { token = CreateStripDebugInfoPass(); CheckOptimization( - "OpSource GLSL 450", "", + Header() + "OpSource GLSL 450", Header(), Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token))); } @@ -332,7 +289,7 @@ TEST(CppInterface, OptimizeMoveConstructPassToken) { Optimizer::PassToken token2(std::move(token1)); CheckOptimization( - "OpSource GLSL 450", "", + Header() + "OpSource GLSL 450", Header(), Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2))); } @@ -342,14 +299,14 @@ TEST(CppInterface, OptimizeMoveAssignPassToken) { token2 = std::move(token1); CheckOptimization( - "OpSource GLSL 450", "", + Header() + "OpSource GLSL 450", Header(), Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2))); } TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); std::vector binary; - ASSERT_TRUE(t.Assemble("OpSource GLSL 450", &binary)); + ASSERT_TRUE(t.Assemble(Header() + "OpSource GLSL 450", &binary)); EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1) .RegisterPass(CreateStripDebugInfoPass()) @@ -357,9 +314,10 @@ TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) { std::string optimized_text; EXPECT_TRUE(t.Disassemble(binary, &optimized_text)); - EXPECT_EQ("", optimized_text); + EXPECT_EQ(Header(), optimized_text); } // TODO(antiagainst): tests for SetMessageConsumer(). -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/diagnostic_test.cpp b/3rdparty/spirv-tools/test/diagnostic_test.cpp index 8b8dbbe34..f86bae113 100644 --- a/3rdparty/spirv-tools/test/diagnostic_test.cpp +++ b/3rdparty/spirv-tools/test/diagnostic_test.cpp @@ -14,13 +14,14 @@ #include #include +#include #include "gmock/gmock.h" -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { -using libspirv::DiagnosticStream; using ::testing::Eq; // Returns a newly created diagnostic value. @@ -68,16 +69,16 @@ TEST(Diagnostic, PrintInvalidDiagnostic) { TEST(DiagnosticStream, ConversionToResultType) { // Check after the DiagnosticStream object is destroyed. spv_result_t value; - { value = DiagnosticStream({}, nullptr, SPV_ERROR_INVALID_TEXT); } + { value = DiagnosticStream({}, nullptr, "", SPV_ERROR_INVALID_TEXT); } EXPECT_EQ(SPV_ERROR_INVALID_TEXT, value); // Check implicit conversion via plain assignment. - value = DiagnosticStream({}, nullptr, SPV_SUCCESS); + value = DiagnosticStream({}, nullptr, "", SPV_SUCCESS); EXPECT_EQ(SPV_SUCCESS, value); // Check conversion via constructor. EXPECT_EQ(SPV_FAILED_MATCH, - spv_result_t(DiagnosticStream({}, nullptr, SPV_FAILED_MATCH))); + spv_result_t(DiagnosticStream({}, nullptr, "", SPV_FAILED_MATCH))); } TEST( @@ -94,7 +95,7 @@ TEST( // Enclose the DiagnosticStream variables in a scope to force destruction. { - DiagnosticStream ds0({}, consumer, SPV_ERROR_INVALID_BINARY); + DiagnosticStream ds0({}, consumer, "", SPV_ERROR_INVALID_BINARY); ds0 << "First"; DiagnosticStream ds1(std::move(ds0)); ds1 << "Second"; @@ -103,4 +104,47 @@ TEST( EXPECT_THAT(messages.str(), Eq("FirstSecond")); } -} // anonymous namespace +TEST(DiagnosticStream, MoveConstructorCanBeDirectlyShiftedTo) { + std::ostringstream messages; + int message_count = 0; + auto consumer = [&messages, &message_count](spv_message_level_t, const char*, + const spv_position_t&, + const char* msg) { + message_count++; + messages << msg; + }; + + // Enclose the DiagnosticStream variables in a scope to force destruction. + { + DiagnosticStream ds0({}, consumer, "", SPV_ERROR_INVALID_BINARY); + ds0 << "First"; + std::move(ds0) << "Second"; + } + EXPECT_THAT(message_count, Eq(1)); + EXPECT_THAT(messages.str(), Eq("FirstSecond")); +} + +TEST(DiagnosticStream, DiagnosticFromLambdaReturnCanStillBeUsed) { + std::ostringstream messages; + int message_count = 0; + auto consumer = [&messages, &message_count](spv_message_level_t, const char*, + const spv_position_t&, + const char* msg) { + message_count++; + messages << msg; + }; + + { + auto emitter = [&consumer]() -> DiagnosticStream { + DiagnosticStream ds0({}, consumer, "", SPV_ERROR_INVALID_BINARY); + ds0 << "First"; + return ds0; + }; + emitter() << "Second"; + } + EXPECT_THAT(message_count, Eq(1)); + EXPECT_THAT(messages.str(), Eq("FirstSecond")); +} + +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/enum_set_test.cpp b/3rdparty/spirv-tools/test/enum_set_test.cpp index 81671e637..ddacd4214 100644 --- a/3rdparty/spirv-tools/test/enum_set_test.cpp +++ b/3rdparty/spirv-tools/test/enum_set_test.cpp @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include + #include "gmock/gmock.h" +#include "source/enum_set.h" +#include "test/unit_spirv.h" -#include "enum_set.h" -#include "unit_spirv.h" - +namespace spvtools { namespace { -using libspirv::CapabilitySet; -using libspirv::EnumSet; using spvtest::ElementsIn; using ::testing::Eq; using ::testing::ValuesIn; @@ -285,4 +286,5 @@ INSTANTIATE_TEST_CASE_P(Samples, CapabilitySetForEachTest, static_cast(0x7fffffff)}}, }), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/enum_string_mapping_test.cpp b/3rdparty/spirv-tools/test/enum_string_mapping_test.cpp index f7f1ef505..b525d6014 100644 --- a/3rdparty/spirv-tools/test/enum_string_mapping_test.cpp +++ b/3rdparty/spirv-tools/test/enum_string_mapping_test.cpp @@ -15,15 +15,16 @@ // Tests for OpExtension validator rules. #include +#include +#include -#include "enum_string_mapping.h" -#include "extensions.h" #include "gtest/gtest.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +namespace spvtools { namespace { -using ::libspirv::Extension; - using ::testing::Values; using ::testing::ValuesIn; @@ -38,8 +39,7 @@ TEST_P(ExtensionTest, TestExtensionFromString) { const Extension extension = param.first; const std::string extension_str = param.second; Extension result_extension; - ASSERT_TRUE(libspirv::GetExtensionFromString(extension_str.c_str(), - &result_extension)); + ASSERT_TRUE(GetExtensionFromString(extension_str.c_str(), &result_extension)); EXPECT_EQ(extension, result_extension); } @@ -47,21 +47,20 @@ TEST_P(ExtensionTest, TestExtensionToString) { const std::pair& param = GetParam(); const Extension extension = param.first; const std::string extension_str = param.second; - const std::string result_str = libspirv::ExtensionToString(extension); + const std::string result_str = ExtensionToString(extension); EXPECT_EQ(extension_str, result_str); } TEST_P(UnknownExtensionTest, TestExtensionFromStringFails) { Extension result_extension; - ASSERT_FALSE( - libspirv::GetExtensionFromString(GetParam().c_str(), &result_extension)); + ASSERT_FALSE(GetExtensionFromString(GetParam().c_str(), &result_extension)); } TEST_P(CapabilityTest, TestCapabilityToString) { const std::pair& param = GetParam(); const SpvCapability capability = param.first; const std::string capability_str = param.second; - const std::string result_str = libspirv::CapabilityToString(capability); + const std::string result_str = CapabilityToString(capability); EXPECT_EQ(capability_str, result_str); } @@ -87,6 +86,7 @@ INSTANTIATE_TEST_CASE_P( {Extension::kSPV_GOOGLE_decorate_string, "SPV_GOOGLE_decorate_string"}, {Extension::kSPV_GOOGLE_hlsl_functionality1, "SPV_GOOGLE_hlsl_functionality1"}, + {Extension::kSPV_KHR_8bit_storage, "SPV_KHR_8bit_storage"}, }))); INSTANTIATE_TEST_CASE_P(UnknownExtensions, UnknownExtensionTest, @@ -191,4 +191,5 @@ INSTANTIATE_TEST_CASE_P( {SpvCapabilityShaderStereoViewNV, "ShaderStereoViewNV"}, {SpvCapabilityPerViewAttributesNV, "PerViewAttributesNV"}})), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/ext_inst.debuginfo_test.cpp b/3rdparty/spirv-tools/test/ext_inst.debuginfo_test.cpp index 608d11992..15fa8f765 100644 --- a/3rdparty/spirv-tools/test/ext_inst.debuginfo_test.cpp +++ b/3rdparty/spirv-tools/test/ext_inst.debuginfo_test.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include -#include #include "DebugInfo.h" -#include "test_fixture.h" +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" // This file tests the correctness of encoding and decoding of instructions // involving the DebugInfo extended instruction set. @@ -24,6 +26,7 @@ // // See https://www.khronos.org/registry/spir-v/specs/1.0/DebugInfo.html +namespace spvtools { namespace { using spvtest::Concatenate; @@ -805,4 +808,5 @@ INSTANTIATE_TEST_CASE_P(DebugInfoDebugMacroUndef, ExtInstDebugInfoRoundTripTest, #undef CASE_EL #undef CASE_ELL -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/ext_inst.glsl_test.cpp b/3rdparty/spirv-tools/test/ext_inst.glsl_test.cpp index 5e569ead2..991c487f1 100644 --- a/3rdparty/spirv-tools/test/ext_inst.glsl_test.cpp +++ b/3rdparty/spirv-tools/test/ext_inst.glsl_test.cpp @@ -13,11 +13,13 @@ // limitations under the License. #include +#include #include -#include "latest_version_glsl_std_450_header.h" -#include "unit_spirv.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { /// Context for an extended instruction. @@ -197,4 +199,5 @@ INSTANTIATE_TEST_CASE_P( {"NClamp", "%5 %5 %5", 81, 8, {5, 5, 5}}, })), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/ext_inst.opencl_test.cpp b/3rdparty/spirv-tools/test/ext_inst.opencl_test.cpp index 6829e3cab..06bc5e848 100644 --- a/3rdparty/spirv-tools/test/ext_inst.opencl_test.cpp +++ b/3rdparty/spirv-tools/test/ext_inst.opencl_test.cpp @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include -#include -#include "latest_version_opencl_std_header.h" -#include "test_fixture.h" +#include "gmock/gmock.h" +#include "source/latest_version_opencl_std_header.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::Concatenate; @@ -366,4 +369,5 @@ INSTANTIATE_TEST_CASE_P( #undef CASE2Lit #undef CASE3Round -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/fix_word_test.cpp b/3rdparty/spirv-tools/test/fix_word_test.cpp index 45ba6e637..b8c3a33d5 100644 --- a/3rdparty/spirv-tools/test/fix_word_test.cpp +++ b/3rdparty/spirv-tools/test/fix_word_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { TEST(FixWord, Default) { @@ -59,4 +60,5 @@ TEST(FixDoubleWord, Reorder) { ASSERT_EQ(result, spvFixDoubleWord(low, high, endian)); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/fuzzers/BUILD.gn b/3rdparty/spirv-tools/test/fuzzers/BUILD.gn new file mode 100644 index 000000000..df8291a56 --- /dev/null +++ b/3rdparty/spirv-tools/test/fuzzers/BUILD.gn @@ -0,0 +1,121 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import("//testing/libfuzzer/fuzzer_test.gni") +import("//testing/test.gni") + +config("fuzzer_config") { + configs = [ "../..:spvtools_config" ] +} + +group("fuzzers") { + testonly = true + deps = [] + + if (!build_with_chromium || use_fuzzing_engine) { + deps += [ ":fuzzers_bin" ] + } +} + +if (!build_with_chromium || use_fuzzing_engine) { + group("fuzzers_bin") { + testonly = true + + deps = [ + ":spvtools_val_fuzzer", + ":spvtools_opt_legalization_fuzzer", + ":spvtools_opt_performance_fuzzer", + ":spvtools_opt_size_fuzzer", + ] + } +} + +template("spvtools_fuzzer") { + source_set(target_name) { + testonly = true + sources = invoker.sources + deps = [ + "../..:spvtools", + "../..:spvtools_opt", + "../..:spvtools_val", + ] + if (defined(invoker.deps)) { + deps += invoker.deps + } + + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ + "//build/config/compiler:no_chromium_code", + ":fuzzer_config", + ] + } +} + +spvtools_fuzzer("spvtools_opt_performance_fuzzer_src") { + sources = [ + "spvtools_opt_performance_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_opt_legalization_fuzzer_src") { + sources = [ + "spvtools_opt_legalization_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_opt_size_fuzzer_src") { + sources = [ + "spvtools_opt_size_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_val_fuzzer_src") { + sources = [ + "spvtools_val_fuzzer.cpp", + ] +} + +if (!build_with_chromium || use_fuzzing_engine) { + fuzzer_test("spvtools_opt_performance_fuzzer") { + sources = [] + deps = [ + ":spvtools_opt_performance_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } + + fuzzer_test("spvtools_opt_legalization_fuzzer") { + sources = [] + deps = [ + ":spvtools_opt_legalization_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } + + fuzzer_test("spvtools_opt_size_fuzzer") { + sources = [] + deps = [ + ":spvtools_opt_size_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } + + fuzzer_test("spvtools_val_fuzzer") { + sources = [] + deps = [ + ":spvtools_val_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } +} diff --git a/3rdparty/spirv-tools/test/fuzzers/corpora/spv/simple.spv b/3rdparty/spirv-tools/test/fuzzers/corpora/spv/simple.spv new file mode 100644 index 000000000..f972a56fd Binary files /dev/null and b/3rdparty/spirv-tools/test/fuzzers/corpora/spv/simple.spv differ diff --git a/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_legalization_fuzzer.cpp b/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_legalization_fuzzer.cpp new file mode 100644 index 000000000..b45a98c37 --- /dev/null +++ b/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_legalization_fuzzer.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "spirv-tools/optimizer.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_3); + optimizer.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + optimizer.RegisterLegalizationPasses(); + optimizer.Run(input.data(), input.size(), &input); + + return 0; +} diff --git a/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_performance_fuzzer.cpp b/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_performance_fuzzer.cpp new file mode 100644 index 000000000..6c3bd6aba --- /dev/null +++ b/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_performance_fuzzer.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "spirv-tools/optimizer.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_3); + optimizer.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + optimizer.RegisterPerformancePasses(); + optimizer.Run(input.data(), input.size(), &input); + + return 0; +} diff --git a/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_size_fuzzer.cpp b/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_size_fuzzer.cpp new file mode 100644 index 000000000..68c797477 --- /dev/null +++ b/3rdparty/spirv-tools/test/fuzzers/spvtools_opt_size_fuzzer.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "spirv-tools/optimizer.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_3); + optimizer.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + optimizer.RegisterSizePasses(); + optimizer.Run(input.data(), input.size(), &input); + + return 0; +} diff --git a/3rdparty/spirv-tools/test/fuzzers/spvtools_val_fuzzer.cpp b/3rdparty/spirv-tools/test/fuzzers/spvtools_val_fuzzer.cpp new file mode 100644 index 000000000..5dc4303b4 --- /dev/null +++ b/3rdparty/spirv-tools/test/fuzzers/spvtools_val_fuzzer.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "spirv-tools/libspirv.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_3); + tools.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + tools.Validate(input); + return 0; +} diff --git a/3rdparty/spirv-tools/test/generator_magic_number_test.cpp b/3rdparty/spirv-tools/test/generator_magic_number_test.cpp index c88022899..bc5fdf57a 100644 --- a/3rdparty/spirv-tools/test/generator_magic_number_test.cpp +++ b/3rdparty/spirv-tools/test/generator_magic_number_test.cpp @@ -12,17 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" - -#include +#include +#include +#include +#include +#include "gmock/gmock.h" #include "source/opcode.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { using ::spvtest::EnumCase; using ::testing::Eq; - -namespace { - using GeneratorMagicNumberTest = ::testing::TestWithParam>; @@ -54,4 +57,6 @@ INSTANTIATE_TEST_CASE_P( {spv_generator_t(1000), "Unknown"}, {spv_generator_t(9999), "Unknown"}, }), ); -} // anonymous namespace + +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/hex_float_test.cpp b/3rdparty/spirv-tools/test/hex_float_test.cpp index a2bff53ed..87450609f 100644 --- a/3rdparty/spirv-tools/test/hex_float_test.cpp +++ b/3rdparty/spirv-tools/test/hex_float_test.cpp @@ -15,21 +15,21 @@ #include #include #include +#include #include #include #include +#include +#include -#include - +#include "gmock/gmock.h" #include "source/util/hex_float.h" -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { +namespace utils { namespace { -using spvutils::BitwiseCast; -using spvutils::Float16; -using spvutils::FloatProxy; -using spvutils::HexFloat; -using spvutils::ParseNormalFloat; + using ::testing::Eq; // In this file "encode" means converting a number into a string, @@ -50,7 +50,7 @@ using RoundTripDoubleTest = ::testing::TestWithParam; template std::string EncodeViaHexFloat(const T& value) { std::stringstream ss; - ss << spvutils::HexFloat(value); + ss << HexFloat(value); return ss.str(); } @@ -68,7 +68,7 @@ TEST_P(HexDoubleTest, EncodeCorrectly) { // Decodes a hex-float string. template FloatProxy Decode(const std::string& str) { - spvutils::HexFloat> decoded(0.f); + HexFloat> decoded(0.f); EXPECT_TRUE((std::stringstream(str) >> decoded).eof()); return decoded.value(); } @@ -229,8 +229,8 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(std::vector< std::pair, std::string>>({ // Various NAN and INF cases - {uint64_t(0xFFF0000000000000LL), "-0x1p+1024"}, //-inf - {uint64_t(0x7FF0000000000000LL), "0x1p+1024"}, //+inf + {uint64_t(0xFFF0000000000000LL), "-0x1p+1024"}, // -inf + {uint64_t(0x7FF0000000000000LL), "0x1p+1024"}, // +inf {uint64_t(0xFFF8000000000000LL), "-0x1.8p+1024"}, // -nan {uint64_t(0xFFF0F00000000000LL), "-0x1.0fp+1024"}, // -nan {uint64_t(0xFFF0000000000001LL), "-0x1.0000000000001p+1024"}, // -nan @@ -575,13 +575,12 @@ INSTANTIATE_TEST_CASE_P( // double is used so that unbiased_exponent can be used with the output // of ldexp directly. int32_t unbiased_exponent(double f) { - return spvutils::HexFloat>(static_cast(f)) + return HexFloat>(static_cast(f)) .getUnbiasedNormalizedExponent(); } int16_t unbiased_half_exponent(uint16_t f) { - return spvutils::HexFloat>(f) - .getUnbiasedNormalizedExponent(); + return HexFloat>(f).getUnbiasedNormalizedExponent(); } TEST(HexFloatOperationTest, UnbiasedExponent) { @@ -591,9 +590,9 @@ TEST(HexFloatOperationTest, UnbiasedExponent) { EXPECT_EQ(42, unbiased_exponent(ldexp(1.0f, 42))); EXPECT_EQ(125, unbiased_exponent(ldexp(1.0f, 125))); - EXPECT_EQ(128, spvutils::HexFloat>( - std::numeric_limits::infinity()) - .getUnbiasedNormalizedExponent()); + EXPECT_EQ(128, + HexFloat>(std::numeric_limits::infinity()) + .getUnbiasedNormalizedExponent()); EXPECT_EQ(-100, unbiased_exponent(ldexp(1.0f, -100))); EXPECT_EQ(-127, unbiased_exponent(ldexp(1.0f, -127))); // First denorm @@ -633,7 +632,7 @@ float float_fractions(const std::vector& fractions) { // raised to the power of exp. uint32_t normalized_significand(const std::vector& fractions, uint32_t exp) { - return spvutils::HexFloat>( + return HexFloat>( static_cast(ldexp(float_fractions(fractions), exp))) .getNormalizedSignificand(); } @@ -674,11 +673,16 @@ TEST(HexFloatOperationTest, NormalizedSignificand) { // For denormalized numbers we expect the normalized significand to // shift as if it were normalized. This means, in practice that the // top_most set bit will be cut off. Looks very similar to above (on purpose) - EXPECT_EQ(bits_set({}), normalized_significand({0}, -127)); - EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -128)); - EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -127)); - EXPECT_EQ(bits_set({}), normalized_significand({22}, -127)); - EXPECT_EQ(bits_set({0}), normalized_significand({21, 22}, -127)); + EXPECT_EQ(bits_set({}), + normalized_significand({0}, static_cast(-127))); + EXPECT_EQ(bits_set({3}), + normalized_significand({0, 4}, static_cast(-128))); + EXPECT_EQ(bits_set({3}), + normalized_significand({0, 4}, static_cast(-127))); + EXPECT_EQ(bits_set({}), + normalized_significand({22}, static_cast(-127))); + EXPECT_EQ(bits_set({0}), + normalized_significand({21, 22}, static_cast(-127))); } // Returns the 32-bit floating point value created by @@ -686,7 +690,7 @@ TEST(HexFloatOperationTest, NormalizedSignificand) { // on a HexFloat> float set_from_sign(bool negative, int32_t unbiased_exponent, uint32_t significand, bool round_denorm_up) { - spvutils::HexFloat> f(0.f); + HexFloat> f(0.f); f.setFromSignUnbiasedExponentAndNormalizedSignificand( negative, unbiased_exponent, significand, round_denorm_up); return f.value().getAsFloat(); @@ -729,17 +733,16 @@ TEST(HexFloatOperationTests, TEST(HexFloatOperationTests, NonRounding) { // Rounding from 32-bit hex-float to 32-bit hex-float should be trivial, // except in the denorm case which is a bit more complex. - using HF = spvutils::HexFloat>; + using HF = HexFloat>; bool carry_bit = false; - spvutils::round_direction rounding[] = { - spvutils::round_direction::kToZero, - spvutils::round_direction::kToNearestEven, - spvutils::round_direction::kToPositiveInfinity, - spvutils::round_direction::kToNegativeInfinity}; + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; // Everything fits, so this should be straight-forward - for (spvutils::round_direction round : rounding) { + for (round_direction round : rounding) { EXPECT_EQ(bits_set({}), HF(0.f).getRoundedNormalizedSignificand(round, &carry_bit)); EXPECT_FALSE(carry_bit); @@ -767,18 +770,18 @@ TEST(HexFloatOperationTests, NonRounding) { } } -using RD = spvutils::round_direction; +using RD = round_direction; struct RoundSignificandCase { float source_float; std::pair expected_results; - spvutils::round_direction round; + round_direction round; }; using HexFloatRoundTest = ::testing::TestWithParam; TEST_P(HexFloatRoundTest, RoundDownToFP16) { - using HF = spvutils::HexFloat>; - using HF16 = spvutils::HexFloat>; + using HF = HexFloat>; + using HF16 = HexFloat>; HF input_value(GetParam().source_float); bool carry_bit = false; @@ -846,18 +849,17 @@ struct UpCastSignificandCase { using HexFloatRoundUpSignificandTest = ::testing::TestWithParam; TEST_P(HexFloatRoundUpSignificandTest, Widening) { - using HF = spvutils::HexFloat>; - using HF16 = spvutils::HexFloat>; + using HF = HexFloat>; + using HF16 = HexFloat>; bool carry_bit = false; - spvutils::round_direction rounding[] = { - spvutils::round_direction::kToZero, - spvutils::round_direction::kToNearestEven, - spvutils::round_direction::kToPositiveInfinity, - spvutils::round_direction::kToNegativeInfinity}; + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; // Everything fits, so everything should just be bit-shifts. - for (spvutils::round_direction round : rounding) { + for (round_direction round : rounding) { carry_bit = false; HF16 input_value(GetParam().source_half); EXPECT_EQ( @@ -884,19 +886,19 @@ INSTANTIATE_TEST_CASE_P( struct DownCastTest { float source_float; uint16_t expected_half; - std::vector directions; + std::vector directions; }; -std::string get_round_text(spvutils::round_direction direction) { +std::string get_round_text(round_direction direction) { #define CASE(round_direction) \ case round_direction: \ return #round_direction switch (direction) { - CASE(spvutils::round_direction::kToZero); - CASE(spvutils::round_direction::kToPositiveInfinity); - CASE(spvutils::round_direction::kToNegativeInfinity); - CASE(spvutils::round_direction::kToNearestEven); + CASE(round_direction::kToZero); + CASE(round_direction::kToPositiveInfinity); + CASE(round_direction::kToNegativeInfinity); + CASE(round_direction::kToNearestEven); } #undef CASE return ""; @@ -905,15 +907,15 @@ std::string get_round_text(spvutils::round_direction direction) { using HexFloatFP32To16Tests = ::testing::TestWithParam; TEST_P(HexFloatFP32To16Tests, NarrowingCasts) { - using HF = spvutils::HexFloat>; - using HF16 = spvutils::HexFloat>; + using HF = HexFloat>; + using HF16 = HexFloat>; HF f(GetParam().source_float); for (auto round : GetParam().directions) { HF16 half(0); f.castTo(half, round); EXPECT_EQ(GetParam().expected_half, half.value().getAsFloat().get_value()) << get_round_text(round) << " " << std::hex - << spvutils::BitwiseCast(GetParam().source_float) + << BitwiseCast(GetParam().source_float) << " cast to: " << half.value().getAsFloat().get_value(); } } @@ -1021,23 +1023,22 @@ struct UpCastCase { using HexFloatFP16To32Tests = ::testing::TestWithParam; TEST_P(HexFloatFP16To32Tests, WideningCasts) { - using HF = spvutils::HexFloat>; - using HF16 = spvutils::HexFloat>; + using HF = HexFloat>; + using HF16 = HexFloat>; HF16 f(GetParam().source_half); - spvutils::round_direction rounding[] = { - spvutils::round_direction::kToZero, - spvutils::round_direction::kToNearestEven, - spvutils::round_direction::kToPositiveInfinity, - spvutils::round_direction::kToNegativeInfinity}; + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; // Everything fits, so everything should just be bit-shifts. - for (spvutils::round_direction round : rounding) { + for (round_direction round : rounding) { HF flt(0.f); f.castTo(flt, round); EXPECT_EQ(GetParam().expected_float, flt.value().getAsFloat()) << get_round_text(round) << " " << std::hex - << spvutils::BitwiseCast(GetParam().source_half) + << BitwiseCast(GetParam().source_half) << " cast to: " << flt.value().getAsFloat(); } } @@ -1066,16 +1067,15 @@ INSTANTIATE_TEST_CASE_P( })), ); TEST(HexFloatOperationTests, NanTests) { - using HF = spvutils::HexFloat>; - using HF16 = spvutils::HexFloat>; - spvutils::round_direction rounding[] = { - spvutils::round_direction::kToZero, - spvutils::round_direction::kToNearestEven, - spvutils::round_direction::kToPositiveInfinity, - spvutils::round_direction::kToNegativeInfinity}; + using HF = HexFloat>; + using HF16 = HexFloat>; + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; // Everything fits, so everything should just be bit-shifts. - for (spvutils::round_direction round : rounding) { + for (round_direction round : rounding) { HF16 f16(0); HF f(0.f); HF(std::numeric_limits::quiet_NaN()).castTo(f16, round); @@ -1326,4 +1326,6 @@ TEST(FloatProxy, Lowest) { } // TODO(awoloszyn): Add fp16 tests and HexFloatTraits. -} // anonymous namespace +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/huffman_codec.cpp b/3rdparty/spirv-tools/test/huffman_codec.cpp index 3d5d293eb..58a781061 100644 --- a/3rdparty/spirv-tools/test/huffman_codec.cpp +++ b/3rdparty/spirv-tools/test/huffman_codec.cpp @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Contains utils for reading, writing and debug printing bit streams. - +#include #include #include #include #include +#include +#include #include "gmock/gmock.h" -#include "util/bit_stream.h" -#include "util/huffman_codec.h" +#include "source/comp/bit_stream.h" +#include "source/comp/huffman_codec.h" +namespace spvtools { +namespace comp { namespace { -using spvutils::BitsToStream; -using spvutils::HuffmanCodec; - const std::map& GetTestSet() { static const std::map hist = { {"a", 4}, {"e", 7}, {"f", 3}, {"h", 2}, {"i", 3}, @@ -312,4 +312,6 @@ TEST(Huffman, CreateFromTextU64) { EXPECT_EQ("00", BitsToStream(bits, num_bits)); } -} // anonymous namespace +} // namespace +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/immediate_int_test.cpp b/3rdparty/spirv-tools/test/immediate_int_test.cpp index 3e95d0c6c..393075a4e 100644 --- a/3rdparty/spirv-tools/test/immediate_int_test.cpp +++ b/3rdparty/spirv-tools/test/immediate_int_test.cpp @@ -16,18 +16,18 @@ #include #include -#include - +#include "gmock/gmock.h" #include "source/util/bitutils.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { +namespace utils { namespace { using spvtest::Concatenate; using spvtest::MakeInstruction; using spvtest::ScopedContext; using spvtest::TextToBinaryTest; -using spvutils::BitwiseCast; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; @@ -286,4 +286,6 @@ TEST_F(ImmediateIntTest, NotInteger) { EXPECT_THAT(CompileFailure("!12K"), StrEq("Invalid immediate integer: !12K")); } -} // anonymous namespace +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/libspirv_macros_test.cpp b/3rdparty/spirv-tools/test/libspirv_macros_test.cpp index 5b9b54131..bf5add671 100644 --- a/3rdparty/spirv-tools/test/libspirv_macros_test.cpp +++ b/3rdparty/spirv-tools/test/libspirv_macros_test.cpp @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { TEST(Macros, BitShiftInnerParens) { ASSERT_EQ(65536, SPV_BIT(2 << 3)); } TEST(Macros, BitShiftOuterParens) { ASSERT_EQ(15, SPV_BIT(4) - 1); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/CMakeLists.txt b/3rdparty/spirv-tools/test/link/CMakeLists.txt index 33810aae8..06aeb9164 100644 --- a/3rdparty/spirv-tools/test/link/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/link/CMakeLists.txt @@ -12,42 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_spvtools_unittest(TARGET link_binary_version - SRCS binary_version_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) -add_spvtools_unittest(TARGET link_memory_model - SRCS memory_model_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) - -add_spvtools_unittest(TARGET link_entry_points - SRCS entry_points_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) - -add_spvtools_unittest(TARGET link_global_values_amount - SRCS global_values_amount_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) - -add_spvtools_unittest(TARGET link_ids_limit - SRCS ids_limit_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) - -add_spvtools_unittest(TARGET link_matching_imports_to_exports - SRCS matching_imports_to_exports_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) - -add_spvtools_unittest(TARGET link_unique_ids - SRCS unique_ids_test.cpp - LIBS SPIRV-Tools-opt SPIRV-Tools-link -) - -add_spvtools_unittest(TARGET link_partial_linkage - SRCS partial_linkage_test.cpp +add_spvtools_unittest(TARGET link + SRCS + binary_version_test.cpp + entry_points_test.cpp + global_values_amount_test.cpp + ids_limit_test.cpp + matching_imports_to_exports_test.cpp + memory_model_test.cpp + partial_linkage_test.cpp + unique_ids_test.cpp LIBS SPIRV-Tools-opt SPIRV-Tools-link ) diff --git a/3rdparty/spirv-tools/test/link/binary_version_test.cpp b/3rdparty/spirv-tools/test/link/binary_version_test.cpp index 14eb91350..0ceeebae2 100644 --- a/3rdparty/spirv-tools/test/link/binary_version_test.cpp +++ b/3rdparty/spirv-tools/test/link/binary_version_test.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using BinaryVersion = spvtest::LinkerTest; @@ -53,4 +56,5 @@ TEST_F(BinaryVersion, LinkerChoosesMaxSpirvVersion) { EXPECT_EQ(0x00000600u, linked_binary[1]); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/entry_points_test.cpp b/3rdparty/spirv-tools/test/link/entry_points_test.cpp index 1df6ec76c..bac8e02ef 100644 --- a/3rdparty/spirv-tools/test/link/entry_points_test.cpp +++ b/3rdparty/spirv-tools/test/link/entry_points_test.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using ::testing::HasSubstr; @@ -87,4 +90,5 @@ OpFunctionEnd "GLCompute, was already defined.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/global_values_amount_test.cpp b/3rdparty/spirv-tools/test/link/global_values_amount_test.cpp index 79fb23c82..2c4ee1f03 100644 --- a/3rdparty/spirv-tools/test/link/global_values_amount_test.cpp +++ b/3rdparty/spirv-tools/test/link/global_values_amount_test.cpp @@ -12,18 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using ::testing::HasSubstr; -class EntryPoints : public spvtest::LinkerTest { +class EntryPointsAmountTest : public spvtest::LinkerTest { public: - EntryPoints() { binaries.reserve(0xFFFF); } + EntryPointsAmountTest() { binaries.reserve(0xFFFF); } - virtual void SetUp() override { + void SetUp() override { binaries.push_back({SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, @@ -100,19 +103,19 @@ class EntryPoints : public spvtest::LinkerTest { binaries.push_back(binary); } } - virtual void TearDown() override { binaries.clear(); } + void TearDown() override { binaries.clear(); } spvtest::Binaries binaries; }; -TEST_F(EntryPoints, UnderLimit) { +TEST_F(EntryPointsAmountTest, UnderLimit) { spvtest::Binary linked_binary; EXPECT_EQ(SPV_SUCCESS, Link(binaries, &linked_binary)); EXPECT_THAT(GetErrorMessage(), std::string()); } -TEST_F(EntryPoints, OverLimit) { +TEST_F(EntryPointsAmountTest, OverLimit) { binaries.push_back({SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, @@ -146,4 +149,5 @@ TEST_F(EntryPoints, OverLimit) { "65536 global values were found.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/ids_limit_test.cpp b/3rdparty/spirv-tools/test/link/ids_limit_test.cpp index 73649cfef..6d7815a24 100644 --- a/3rdparty/spirv-tools/test/link/ids_limit_test.cpp +++ b/3rdparty/spirv-tools/test/link/ids_limit_test.cpp @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using ::testing::HasSubstr; - using IdsLimit = spvtest::LinkerTest; TEST_F(IdsLimit, UnderLimit) { @@ -66,4 +68,5 @@ TEST_F(IdsLimit, OverLimit) { "the current ID bound.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/linker_fixture.h b/3rdparty/spirv-tools/test/link/linker_fixture.h index a1e3ec866..303f1bfd5 100644 --- a/3rdparty/spirv-tools/test/link/linker_fixture.h +++ b/3rdparty/spirv-tools/test/link/linker_fixture.h @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_LINK_LINK_TEST -#define LIBSPIRV_TEST_LINK_LINK_TEST +#ifndef TEST_LINK_LINKER_FIXTURE_H_ +#define TEST_LINK_LINKER_FIXTURE_H_ #include +#include +#include #include "source/spirv_constant.h" -#include "unit_spirv.h" - #include "spirv-tools/linker.hpp" +#include "test/unit_spirv.h" namespace spvtest { @@ -60,7 +61,7 @@ class LinkerTest : public ::testing::Test { tools_.SetMessageConsumer(consumer); } - virtual void TearDown() override { error_message_.clear(); } + void TearDown() override { error_message_.clear(); } // Assembles each of the given strings into SPIR-V binaries before linking // them together. SPV_ERROR_INVALID_TEXT is returned if the assembling failed @@ -121,4 +122,4 @@ class LinkerTest : public ::testing::Test { } // namespace spvtest -#endif // LIBSPIRV_TEST_LINK_LINK_TEST +#endif // TEST_LINK_LINKER_FIXTURE_H_ diff --git a/3rdparty/spirv-tools/test/link/matching_imports_to_exports_test.cpp b/3rdparty/spirv-tools/test/link/matching_imports_to_exports_test.cpp index 2126d928b..59e62d51b 100644 --- a/3rdparty/spirv-tools/test/link/matching_imports_to_exports_test.cpp +++ b/3rdparty/spirv-tools/test/link/matching_imports_to_exports_test.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using ::testing::HasSubstr; @@ -87,7 +90,7 @@ OpDecorate %1 LinkageAttributes "foo" Export )"; spvtest::Binary linked_binary; - spvtools::LinkerOptions options; + LinkerOptions options; options.SetCreateLibrary(true); EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body}, &linked_binary, options)) << GetErrorMessage(); @@ -396,4 +399,5 @@ OpFunctionEnd EXPECT_EQ(expected_res, res_body); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/memory_model_test.cpp b/3rdparty/spirv-tools/test/link/memory_model_test.cpp index 9ab1c2273..2add5046c 100644 --- a/3rdparty/spirv-tools/test/link/memory_model_test.cpp +++ b/3rdparty/spirv-tools/test/link/memory_model_test.cpp @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using ::testing::HasSubstr; - using MemoryModel = spvtest::LinkerTest; TEST_F(MemoryModel, Default) { @@ -68,4 +70,5 @@ OpMemoryModel Logical GLSL450 HasSubstr("Conflicting memory models: Simple vs GLSL450.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/partial_linkage_test.cpp b/3rdparty/spirv-tools/test/link/partial_linkage_test.cpp index 89142c0c8..c43b06e55 100644 --- a/3rdparty/spirv-tools/test/link/partial_linkage_test.cpp +++ b/3rdparty/spirv-tools/test/link/partial_linkage_test.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using ::testing::HasSubstr; @@ -38,7 +41,7 @@ OpDecorate %1 LinkageAttributes "bar" Export )"; spvtest::Binary linked_binary; - spvtools::LinkerOptions linker_options; + LinkerOptions linker_options; linker_options.SetAllowPartialLinkage(true); ASSERT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary, linker_options)); @@ -82,4 +85,5 @@ OpDecorate %1 LinkageAttributes "bar" Export HasSubstr("Unresolved external reference to \"foo\".")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/link/unique_ids_test.cpp b/3rdparty/spirv-tools/test/link/unique_ids_test.cpp index c926f0e67..55c70ea67 100644 --- a/3rdparty/spirv-tools/test/link/unique_ids_test.cpp +++ b/3rdparty/spirv-tools/test/link/unique_ids_test.cpp @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" -#include "linker_fixture.h" +#include +#include +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { namespace { using UniqueIds = spvtest::LinkerTest; @@ -128,10 +132,11 @@ TEST_F(UniqueIds, UniquelyMerged) { // clang-format on spvtest::Binary linked_binary; - spvtools::LinkerOptions options; + LinkerOptions options; options.SetVerifyIds(true); spv_result_t res = AssembleAndLink(bodies, &linked_binary, options); EXPECT_EQ(SPV_SUCCESS, res); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/log_test.cpp b/3rdparty/spirv-tools/test/log_test.cpp index f14f77d82..ec66aa1ec 100644 --- a/3rdparty/spirv-tools/test/log_test.cpp +++ b/3rdparty/spirv-tools/test/log_test.cpp @@ -12,15 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include - -#include "message.h" -#include "opt/log.h" +#include "source/opt/log.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +namespace spvtools { namespace { -using namespace spvtools; using ::testing::MatchesRegex; TEST(Log, Unimplemented) { @@ -51,4 +49,5 @@ TEST(Log, Unreachable) { EXPECT_EQ(1, invocation); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/move_to_front_test.cpp b/3rdparty/spirv-tools/test/move_to_front_test.cpp index 08c29212a..c95d38656 100644 --- a/3rdparty/spirv-tools/test/move_to_front_test.cpp +++ b/3rdparty/spirv-tools/test/move_to_front_test.cpp @@ -15,17 +15,18 @@ #include #include #include +#include +#include #include "gmock/gmock.h" -#include "util/move_to_front.h" +#include "source/comp/move_to_front.h" +namespace spvtools { +namespace comp { namespace { -using spvutils::MoveToFront; -using spvutils::MultiMoveToFront; - // Class used to test the inner workings of MoveToFront. -class MoveToFrontTester : public MoveToFront { +class MoveToFrontTester : public MoveToFront { public: // Inserts the value in the internal tree data structure. For testing only. void TestInsert(uint32_t val) { InsertNode(CreateNode(val, val)); } @@ -822,120 +823,6 @@ TEST(MoveToFront, LargerScale) { ASSERT_EQ(1000u, value); } -TEST(MoveToFront, String) { - MoveToFront mtf; - - EXPECT_TRUE(mtf.Insert("AAA")); - EXPECT_TRUE(mtf.Insert("BBB")); - EXPECT_TRUE(mtf.Insert("CCC")); - EXPECT_FALSE(mtf.Insert("AAA")); - - EXPECT_TRUE(mtf.HasValue("AAA")); - EXPECT_FALSE(mtf.HasValue("DDD")); - - std::string value; - EXPECT_TRUE(mtf.ValueFromRank(2, &value)); - EXPECT_EQ("BBB", value); - - EXPECT_TRUE(mtf.ValueFromRank(2, &value)); - EXPECT_EQ("CCC", value); - - uint32_t rank = 0; - EXPECT_TRUE(mtf.RankFromValue("AAA", &rank)); - EXPECT_EQ(3u, rank); - - EXPECT_FALSE(mtf.ValueFromRank(0, &value)); - EXPECT_FALSE(mtf.RankFromValue("ABC", &rank)); - EXPECT_FALSE(mtf.Remove("ABC")); - - EXPECT_TRUE(mtf.Remove("AAA")); - EXPECT_FALSE(mtf.Remove("AAA")); - EXPECT_FALSE(mtf.RankFromValue("AAA", &rank)); - - EXPECT_TRUE(mtf.Insert("AAA")); - EXPECT_TRUE(mtf.RankFromValue("AAA", &rank)); - EXPECT_EQ(1u, rank); - - EXPECT_TRUE(mtf.Promote("BBB")); - EXPECT_TRUE(mtf.RankFromValue("BBB", &rank)); - EXPECT_EQ(1u, rank); -} - -TEST(MultiMoveToFront, Empty) { - MultiMoveToFront multi_mtf; - - uint32_t rank = 0; - std::string value; - - EXPECT_EQ(0u, multi_mtf.GetSize(1001)); - EXPECT_FALSE(multi_mtf.RankFromValue(1001, "AAA", &rank)); - EXPECT_FALSE(multi_mtf.ValueFromRank(1001, 1, &value)); - EXPECT_FALSE(multi_mtf.HasValue(1001, "AAA")); - EXPECT_FALSE(multi_mtf.Remove(1001, "AAA")); -} - -TEST(MultiMoveToFront, TwoSequences) { - MultiMoveToFront multi_mtf; - - uint32_t rank = 0; - std::string value; - - EXPECT_TRUE(multi_mtf.Insert(1001, "AAA")); - - EXPECT_EQ(1u, multi_mtf.GetSize(1001)); - EXPECT_EQ(0u, multi_mtf.GetSize(1002)); - EXPECT_TRUE(multi_mtf.HasValue(1001, "AAA")); - EXPECT_FALSE(multi_mtf.HasValue(1002, "AAA")); - - EXPECT_TRUE(multi_mtf.RankFromValue(1001, "AAA", &rank)); - EXPECT_EQ(1u, rank); - EXPECT_FALSE(multi_mtf.RankFromValue(1002, "AAA", &rank)); - - EXPECT_TRUE(multi_mtf.ValueFromRank(1001, rank, &value)); - EXPECT_EQ("AAA", value); - EXPECT_FALSE(multi_mtf.ValueFromRank(1002, rank, &value)); - - EXPECT_TRUE(multi_mtf.Insert(1001, "BBB")); - - EXPECT_EQ(2u, multi_mtf.GetSize(1001)); - EXPECT_EQ(0u, multi_mtf.GetSize(1002)); - EXPECT_TRUE(multi_mtf.HasValue(1001, "BBB")); - EXPECT_FALSE(multi_mtf.HasValue(1002, "BBB")); - - EXPECT_TRUE(multi_mtf.RankFromValue(1001, "BBB", &rank)); - EXPECT_EQ(1u, rank); - EXPECT_FALSE(multi_mtf.RankFromValue(1002, "BBB", &rank)); - - EXPECT_TRUE(multi_mtf.ValueFromRank(1001, rank, &value)); - EXPECT_EQ("BBB", value); - EXPECT_FALSE(multi_mtf.ValueFromRank(1002, rank, &value)); - - EXPECT_TRUE(multi_mtf.Insert(1002, "AAA")); - - EXPECT_EQ(2u, multi_mtf.GetSize(1001)); - EXPECT_EQ(1u, multi_mtf.GetSize(1002)); - EXPECT_TRUE(multi_mtf.HasValue(1002, "AAA")); - - EXPECT_TRUE(multi_mtf.RankFromValue(1002, "AAA", &rank)); - EXPECT_EQ(1u, rank); - - EXPECT_TRUE(multi_mtf.RankFromValue(1001, "AAA", &rank)); - EXPECT_EQ(2u, rank); - - multi_mtf.Promote("BBB"); - - EXPECT_TRUE(multi_mtf.RankFromValue(1001, "BBB", &rank)); - EXPECT_EQ(1u, rank); - - EXPECT_TRUE(multi_mtf.Insert(1002, "CCC")); - EXPECT_TRUE(multi_mtf.RankFromValue(1002, "CCC", &rank)); - EXPECT_EQ(1u, rank); - - multi_mtf.Promote("AAA"); - EXPECT_TRUE(multi_mtf.RankFromValue(1001, "AAA", &rank)); - EXPECT_EQ(1u, rank); - EXPECT_TRUE(multi_mtf.RankFromValue(1002, "AAA", &rank)); - EXPECT_EQ(1u, rank); -} - -} // anonymous namespace +} // namespace +} // namespace comp +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/name_mapper_test.cpp b/3rdparty/spirv-tools/test/name_mapper_test.cpp index d2a5f4e4d..9a9ee8aa0 100644 --- a/3rdparty/spirv-tools/test/name_mapper_test.cpp +++ b/3rdparty/spirv-tools/test/name_mapper_test.cpp @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "gmock/gmock.h" - -#include "test_fixture.h" -#include "unit_spirv.h" - #include "source/name_mapper.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { -using libspirv::FriendlyNameMapper; -using libspirv::NameMapper; using spvtest::ScopedContext; using ::testing::Eq; -namespace { - TEST(TrivialNameTest, Samples) { - auto mapper = libspirv::GetTrivialNameMapper(); + auto mapper = GetTrivialNameMapper(); EXPECT_EQ(mapper(1), "1"); EXPECT_EQ(mapper(1999), "1999"); EXPECT_EQ(mapper(1024), "1024"); @@ -343,4 +343,5 @@ INSTANTIATE_TEST_CASE_P( {"%1 = OpTypeBool\n%2 = OpConstantFalse %1", 2, "false"}, }), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/named_id_test.cpp b/3rdparty/spirv-tools/test/named_id_test.cpp index f83f5e864..4ba54adc3 100644 --- a/3rdparty/spirv-tools/test/named_id_test.cpp +++ b/3rdparty/spirv-tools/test/named_id_test.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include -#include "test_fixture.h" -#include "unit_spirv.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using NamedIdTest = spvtest::TextToBinaryTest; @@ -81,4 +83,5 @@ INSTANTIATE_TEST_CASE_P( {"5", false}, {"32", false}, {"foo", false}, {"a%bar", false}})), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opcode_make_test.cpp b/3rdparty/spirv-tools/test/opcode_make_test.cpp index 5353b7190..6481ef326 100644 --- a/3rdparty/spirv-tools/test/opcode_make_test.cpp +++ b/3rdparty/spirv-tools/test/opcode_make_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { // A sampling of word counts. Covers extreme points well, and all bit @@ -39,4 +40,5 @@ TEST(OpcodeMake, Samples) { } } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opcode_require_capabilities_test.cpp b/3rdparty/spirv-tools/test/opcode_require_capabilities_test.cpp index 2aa1b86e9..32bf1dc08 100644 --- a/3rdparty/spirv-tools/test/opcode_require_capabilities_test.cpp +++ b/3rdparty/spirv-tools/test/opcode_require_capabilities_test.cpp @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" -#include "enum_set.h" +#include "source/enum_set.h" +namespace spvtools { namespace { -using libspirv::CapabilitySet; using spvtest::ElementsIn; // Capabilities required by an Opcode. @@ -74,4 +74,5 @@ INSTANTIATE_TEST_CASE_P( SpvOpGetKernelMaxNumSubgroups, CapabilitySet{SpvCapabilitySubgroupDispatch}}), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opcode_split_test.cpp b/3rdparty/spirv-tools/test/opcode_split_test.cpp index f42d903ae..43fedb385 100644 --- a/3rdparty/spirv-tools/test/opcode_split_test.cpp +++ b/3rdparty/spirv-tools/test/opcode_split_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { TEST(OpcodeSplit, Default) { @@ -25,4 +26,5 @@ TEST(OpcodeSplit, Default) { ASSERT_EQ(23, opcode); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opcode_table_get_test.cpp b/3rdparty/spirv-tools/test/opcode_table_get_test.cpp index f2272f076..6f80ad7d8 100644 --- a/3rdparty/spirv-tools/test/opcode_table_get_test.cpp +++ b/3rdparty/spirv-tools/test/opcode_table_get_test.cpp @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - -#include "unit_spirv.h" +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using GetTargetOpcodeTableGetTest = ::testing::TestWithParam; @@ -35,4 +35,5 @@ TEST_P(GetTargetOpcodeTableGetTest, InvalidPointerTable) { INSTANTIATE_TEST_CASE_P(OpcodeTableGet, GetTargetOpcodeTableGetTest, ValuesIn(spvtest::AllTargetEnvironments())); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/operand_capabilities_test.cpp b/3rdparty/spirv-tools/test/operand_capabilities_test.cpp index ddc60e267..0aeb505f0 100644 --- a/3rdparty/spirv-tools/test/operand_capabilities_test.cpp +++ b/3rdparty/spirv-tools/test/operand_capabilities_test.cpp @@ -15,18 +15,16 @@ // Test capability dependencies for enums. #include +#include #include "gmock/gmock.h" +#include "source/enum_set.h" +#include "test/unit_spirv.h" -#include "enum_set.h" -#include "unit_spirv.h" - +namespace spvtools { namespace { -using libspirv::CapabilitySet; using spvtest::ElementsIn; -using std::get; -using std::tuple; using ::testing::Combine; using ::testing::Eq; using ::testing::TestWithParam; @@ -42,23 +40,24 @@ struct EnumCapabilityCase { // Test fixture for testing EnumCapabilityCases. using EnumCapabilityTest = - TestWithParam>; + TestWithParam>; TEST_P(EnumCapabilityTest, Sample) { - const auto env = get<0>(GetParam()); + const auto env = std::get<0>(GetParam()); const auto context = spvContextCreate(env); - const libspirv::AssemblyGrammar grammar(context); + const AssemblyGrammar grammar(context); spv_operand_desc entry; ASSERT_EQ(SPV_SUCCESS, - grammar.lookupOperand(get<1>(GetParam()).type, - get<1>(GetParam()).value, &entry)); + grammar.lookupOperand(std::get<1>(GetParam()).type, + std::get<1>(GetParam()).value, &entry)); const auto cap_set = grammar.filterCapsAgainstTargetEnv( entry->capabilities, entry->numCapabilities); EXPECT_THAT(ElementsIn(cap_set), - Eq(ElementsIn(get<1>(GetParam()).expected_capabilities))) - << " capability value " << get<1>(GetParam()).value; + Eq(ElementsIn(std::get<1>(GetParam()).expected_capabilities))) + << " capability value " << std::get<1>(GetParam()).value; + spvContextDestroy(context); } #define CASE0(TYPE, VALUE) \ @@ -77,6 +76,12 @@ TEST_P(EnumCapabilityTest, Sample) { SpvCapability##CAP1, SpvCapability##CAP2 \ } \ } +#define CASE3(TYPE, VALUE, CAP1, CAP2, CAP3) \ + { \ + SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), CapabilitySet { \ + SpvCapability##CAP1, SpvCapability##CAP2, SpvCapability##CAP3 \ + } \ + } #define CASE5(TYPE, VALUE, CAP1, CAP2, CAP3, CAP4, CAP5) \ { \ SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), CapabilitySet { \ @@ -348,7 +353,7 @@ INSTANTIATE_TEST_CASE_P( CASE0(OPTIONAL_IMAGE, ImageOperandsGradMask), CASE0(OPTIONAL_IMAGE, ImageOperandsConstOffsetMask), CASE1(OPTIONAL_IMAGE, ImageOperandsOffsetMask, ImageGatherExtended), - CASE0(OPTIONAL_IMAGE, ImageOperandsConstOffsetsMask), + CASE1(OPTIONAL_IMAGE, ImageOperandsConstOffsetsMask, ImageGatherExtended), CASE0(OPTIONAL_IMAGE, ImageOperandsSampleMask), CASE1(OPTIONAL_IMAGE, ImageOperandsMinLodMask, MinLod), // clang-format on @@ -615,9 +620,12 @@ INSTANTIATE_TEST_CASE_P( GroupOperation, EnumCapabilityTest, Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), ValuesIn(std::vector{ - CASE1(GROUP_OPERATION, GroupOperationReduce, Kernel), - CASE1(GROUP_OPERATION, GroupOperationInclusiveScan, Kernel), - CASE1(GROUP_OPERATION, GroupOperationExclusiveScan, Kernel), + CASE3(GROUP_OPERATION, GroupOperationReduce, Kernel, + GroupNonUniformArithmetic, GroupNonUniformBallot), + CASE3(GROUP_OPERATION, GroupOperationInclusiveScan, Kernel, + GroupNonUniformArithmetic, GroupNonUniformBallot), + CASE3(GROUP_OPERATION, GroupOperationExclusiveScan, Kernel, + GroupNonUniformArithmetic, GroupNonUniformBallot), })), ); // See SPIR-V Section 3.29 Kernel Enqueue Flags @@ -687,7 +695,7 @@ INSTANTIATE_TEST_CASE_P( CASE1(CAPABILITY, CapabilityImageRect, SampledRect), CASE1(CAPABILITY, CapabilitySampledRect, Shader), CASE1(CAPABILITY, CapabilityGenericPointer, Addresses), - CASE1(CAPABILITY, CapabilityInt8, Kernel), + CASE0(CAPABILITY, CapabilityInt8), CASE1(CAPABILITY, CapabilityInputAttachment, Shader), CASE1(CAPABILITY, CapabilitySparseResidency, Shader), CASE1(CAPABILITY, CapabilityMinLod, Shader), @@ -720,4 +728,5 @@ INSTANTIATE_TEST_CASE_P( #undef CASE1 #undef CASE2 -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/operand_pattern_test.cpp b/3rdparty/spirv-tools/test/operand_pattern_test.cpp index 8f35fcf1c..b3e302490 100644 --- a/3rdparty/spirv-tools/test/operand_pattern_test.cpp +++ b/3rdparty/spirv-tools/test/operand_pattern_test.cpp @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include #include "gmock/gmock.h" #include "source/operand.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { using ::testing::Eq; -namespace { - TEST(OperandPattern, InitiallyEmpty) { spv_operand_pattern_t empty; EXPECT_THAT(empty, Eq(spv_operand_pattern_t{})); @@ -262,4 +264,5 @@ TEST(AlternatePatternFollowingImmediate, ResultIdBack) { SPV_OPERAND_TYPE_RESULT_ID})); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/operand_test.cpp b/3rdparty/spirv-tools/test/operand_test.cpp index 464d38fcc..08522c323 100644 --- a/3rdparty/spirv-tools/test/operand_test.cpp +++ b/3rdparty/spirv-tools/test/operand_test.cpp @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include "test/unit_spirv.h" + +namespace spvtools { namespace { using GetTargetTest = ::testing::TestWithParam; -using std::vector; using ::testing::ValuesIn; TEST_P(GetTargetTest, Default) { @@ -32,9 +34,9 @@ TEST_P(GetTargetTest, InvalidPointerTable) { } INSTANTIATE_TEST_CASE_P(OperandTableGet, GetTargetTest, - ValuesIn(vector{SPV_ENV_UNIVERSAL_1_0, - SPV_ENV_UNIVERSAL_1_1, - SPV_ENV_VULKAN_1_0}), ); + ValuesIn(std::vector{ + SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0}), ); TEST(OperandString, AllAreDefinedExceptVariable) { // None has no string, so don't test it. @@ -69,4 +71,5 @@ TEST(OperandIsConcreteMask, Sample) { spvOperandIsConcreteMask(SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/CMakeLists.txt b/3rdparty/spirv-tools/test/opt/CMakeLists.txt index de1cd16d3..f2741f673 100644 --- a/3rdparty/spirv-tools/test/opt/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/opt/CMakeLists.txt @@ -297,10 +297,16 @@ add_spvtools_unittest(TARGET replace_invalid_opc LIBS SPIRV-Tools-opt ) +add_spvtools_unittest(TARGET register_liveness + SRCS register_liveness.cpp + LIBS SPIRV-Tools-opt +) + add_spvtools_unittest(TARGET simplification SRCS simplification_test.cpp pass_utils.cpp LIBS SPIRV-Tools-opt ) + add_spvtools_unittest(TARGET copy_prop_array SRCS copy_prop_array_test.cpp pass_utils.cpp LIBS SPIRV-Tools-opt @@ -311,3 +317,23 @@ add_spvtools_unittest(TARGET scalar_analysis LIBS SPIRV-Tools-opt ) +add_spvtools_unittest(TARGET vector_dce + SRCS vector_dce_test.cpp pass_utils.cpp + LIBS SPIRV-Tools-opt +) + +add_spvtools_unittest(TARGET reduce_load_size + SRCS reduce_load_size_test.cpp pass_utils.cpp + LIBS SPIRV-Tools-opt +) + +add_spvtools_unittest(TARGET constant_manager + SRCS constant_manager_test.cpp + LIBS SPIRV-Tools-opt +) + +add_spvtools_unittest(TARGET combine_access_chains + SRCS combine_access_chains_test.cpp + LIBS SPIRV-Tools-opt +) + diff --git a/3rdparty/spirv-tools/test/opt/aggressive_dead_code_elim_test.cpp b/3rdparty/spirv-tools/test/opt/aggressive_dead_code_elim_test.cpp index b3346dbff..287fcef0f 100644 --- a/3rdparty/spirv-tools/test/opt/aggressive_dead_code_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/aggressive_dead_code_elim_test.cpp @@ -13,14 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using AggressiveDCETest = PassTest<::testing::Test>; TEST_F(AggressiveDCETest, EliminateExtendedInst) { @@ -103,7 +106,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs1 + names_before + predefs2 + func_before, predefs1 + names_after + predefs2 + func_after, true, true); } @@ -219,7 +222,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs1 + names_before + predefs2_before + func_before, predefs1 + names_after + predefs2_after + func_after, true, true); } @@ -321,7 +324,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs1 + names_before + predefs2_before + func_before, predefs1 + names_after + predefs2_after + func_after, true, true); } @@ -405,7 +408,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs1 + names_before + predefs2 + func_before, predefs1 + names_after + predefs2 + func_after, true, true); } @@ -490,7 +493,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs1 + names_before + predefs2 + func_before, predefs1 + names_after + predefs2 + func_after, true, true); } @@ -546,7 +549,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, ElimWithCall) { @@ -672,8 +675,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - defs_before + func_before, defs_after + func_after, true, true); + SinglePassRunAndCheck(defs_before + func_before, + defs_after + func_after, true, true); } TEST_F(AggressiveDCETest, NoParamElim) { @@ -802,8 +805,8 @@ OpReturnValue %27 OpFunctionEnd )"; - SinglePassRunAndCheck( - defs_before + func_before, defs_after + func_after, true, true); + SinglePassRunAndCheck(defs_before + func_before, + defs_after + func_after, true, true); } TEST_F(AggressiveDCETest, ElimOpaque) { @@ -904,8 +907,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - defs_before + func_before, defs_after + func_after, true, true); + SinglePassRunAndCheck(defs_before + func_before, + defs_after + func_after, true, true); } TEST_F(AggressiveDCETest, NoParamStoreElim) { @@ -975,7 +978,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, PrivateStoreElimInEntryNoCalls) { @@ -1080,7 +1083,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + main_before, predefs_after + main_after, true, true); } @@ -1135,7 +1138,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoPrivateStoreElimWithCall) { @@ -1200,7 +1203,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoPrivateStoreElimInNonEntry) { @@ -1265,7 +1268,113 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, WorkgroupStoreElimInEntryNoCalls) { + // Eliminate stores to private in entry point with no calls + // Note: Not legal GLSL + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 1) in vec4 Dead; + // layout(location = 0) out vec4 OutColor; + // + // workgroup vec4 dv; + // + // void main() + // { + // vec4 v = BaseColor; + // dv = Dead; + // OutColor = v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %Dead Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Workgroup_v4float = OpTypePointer Workgroup %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%dv = OpVariable %_ptr_Workgroup_v4float Workgroup +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %Dead Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string main_before = + R"(%main = OpFunction %void None %9 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +%18 = OpLoad %v4float %Dead +OpStore %dv %18 +%19 = OpLoad %v4float %v +%20 = OpFNegate %v4float %19 +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + const std::string main_after = + R"(%main = OpFunction %void None %9 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +%19 = OpLoad %v4float %v +%20 = OpFNegate %v4float %19 +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + main_before, predefs_after + main_after, true, true); } TEST_F(AggressiveDCETest, EliminateDeadIfThenElse) { @@ -1376,7 +1485,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -1480,7 +1589,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -1587,7 +1696,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } TEST_F(AggressiveDCETest, EliminateDeadIfThenElseNested) { @@ -1724,7 +1833,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -1799,7 +1908,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateLiveIfThenElseNested) { @@ -1899,7 +2008,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateIfWithPhi) { @@ -1965,7 +2074,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateIfBreak) { @@ -2046,7 +2155,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateIfBreak2) { @@ -2144,7 +2253,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, EliminateEntireUselessLoop) { @@ -2288,7 +2397,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs1 + names_before + predefs2_before + func_before, predefs1 + names_after + predefs2_after + func_after, true, true); } @@ -2368,7 +2477,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateLiveLoop) { @@ -2451,7 +2560,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, EliminateEntireFunctionBody) { @@ -2555,7 +2664,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -2755,7 +2864,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -2915,7 +3024,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -3038,7 +3147,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + func_before, predefs_after + func_after, true, true); } @@ -3168,7 +3277,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateIfContinue) { @@ -3275,7 +3384,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateIfContinue2) { @@ -3379,7 +3488,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, NoEliminateIfContinue3) { @@ -3485,7 +3594,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(AggressiveDCETest, PointerVariable) { @@ -3584,7 +3693,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } // %dead is unused. Make sure we remove it along with its name. @@ -3628,7 +3737,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } // Delete %dead because it is unreferenced. Then %initializer becomes @@ -3675,7 +3784,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } // Keep %live because it is used, and its initializer. @@ -3709,7 +3818,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, before, true, true); + SinglePassRunAndCheck(before, before, true, true); } // This test that the decoration associated with a variable are removed when the @@ -3761,7 +3870,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } #ifdef SPIRV_EFFCEE @@ -3808,14 +3917,14 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE TEST_F(AggressiveDCETest, LiveNestedSwitch) { const std::string text = R"(OpCapability Shader OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %func "func" %3 +OpEntryPoint Fragment %func "func" %3 %10 OpExecutionMode %func OriginUpperLeft OpName %func "func" %void = OpTypeVoid @@ -3848,7 +3957,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(text, text, false, true); + SinglePassRunAndCheck(text, text, false, true); } TEST_F(AggressiveDCETest, BasicDeleteDeadFunction) { @@ -3885,7 +3994,7 @@ TEST_F(AggressiveDCETest, BasicDeleteDeadFunction) { }; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(common_code, dead_function)), JoinAllInsts(common_code), /* skip_nop = */ true); } @@ -3922,9 +4031,9 @@ TEST_F(AggressiveDCETest, BasicKeepLiveFunction) { SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); std::string assembly = JoinAllInsts(text); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( assembly, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); EXPECT_EQ(assembly, std::get<0>(result)); } @@ -3982,8 +4091,8 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(text, expected_output, - /* skip_nop = */ true); + SinglePassRunAndCheck(text, expected_output, + /* skip_nop = */ true); } #ifdef SPIRV_EFFCEE @@ -4015,7 +4124,7 @@ TEST_F(AggressiveDCETest, BasicAllDeadConstants) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE @@ -4071,7 +4180,7 @@ TEST_F(AggressiveDCETest, BasicNoneDeadConstants) { // clang-format on }; // All constants are used, so none of them should be eliminated. - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(text), JoinAllInsts(text), /* skip_nop = */ true); } @@ -4140,8 +4249,7 @@ TEST_P(EliminateDeadConstantTest, Custom) { // Do not enable validation. As the input code is invalid from the base // tests (ported from other passes). - SinglePassRunAndMatch(assembly_with_dead_const, - false); + SinglePassRunAndMatch(assembly_with_dead_const, false); } INSTANTIATE_TEST_CASE_P( @@ -4998,7 +5106,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(AggressiveDCETest, ParitallyDeadDecorationGroup) { @@ -5032,7 +5140,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(AggressiveDCETest, ParitallyDeadDecorationGroupDifferentGroupDecorate) { @@ -5068,7 +5176,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(AggressiveDCETest, DeadGroupMemberDecorate) { @@ -5095,7 +5203,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(AggressiveDCETest, PartiallyDeadGroupMemberDecorate) { @@ -5133,7 +5241,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(AggressiveDCETest, @@ -5174,7 +5282,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } // Test for #1404 @@ -5199,7 +5307,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE @@ -5245,7 +5353,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } TEST_F(AggressiveDCETest, BreaksDontVisitPhis) { @@ -5286,8 +5394,8 @@ OpReturn OpFunctionEnd )"; - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, - std::get<1>(SinglePassRunAndDisassemble( + EXPECT_EQ(Pass::Status::SuccessWithoutChange, + std::get<1>(SinglePassRunAndDisassemble( text, false, true))); } @@ -5326,7 +5434,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } // Test for #1212 @@ -5366,7 +5474,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } TEST_F(AggressiveDCETest, AtomicAdd) { @@ -5407,13 +5515,309 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } +TEST_F(AggressiveDCETest, SafelyRemoveDecorateString) { + const std::string preamble = R"(OpCapability Shader +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +)"; + + const std::string body_before = + R"(OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "FOOBAR" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%2 = OpVariable %_ptr_StorageBuffer_uint StorageBuffer +%1 = OpFunction %void None %4 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string body_after = R"(%void = OpTypeVoid +%4 = OpTypeFunction %void +%1 = OpFunction %void None %4 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(preamble + body_before, + preamble + body_after, true, true); +} + +TEST_F(AggressiveDCETest, CopyMemoryToGlobal) { + // |local| is loaded in an OpCopyMemory instruction. So the store must be + // kept alive. + const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local "local" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local = OpVariable %_ptr_Function_v4float Function +OpStore %local %12 +OpCopyMemory %global %local +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} + +TEST_F(AggressiveDCETest, CopyMemoryToLocal) { + // Make sure the store to |local2| using OpCopyMemory is kept and keeps + // |local1| alive. + const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %local2 "local2" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +%local2 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %local2 %local1 +OpCopyMemory %global %local2 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} + +TEST_F(AggressiveDCETest, RemoveCopyMemoryToLocal) { + // Test that we remove function scope variables that are stored to using + // OpCopyMemory, but are never loaded. We can remove both |local1| and + // |local2|. + const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %local2 "local2" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +%local2 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %local2 %local1 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + const std::string result = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, result, true, true); +} + +TEST_F(AggressiveDCETest, RemoveCopyMemoryToLocal2) { + // We are able to remove "local2" because it is not loaded, but have to keep + // the stores to "local1". + const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %local2 "local2" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +%local2 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %local2 %local1 +OpCopyMemory %global %local1 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + const std::string result = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %global %local1 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, result, true, true); +} + +TEST_F(AggressiveDCETest, StructuredIfWithConditionalExit) { + // We are able to remove "local2" because it is not loaded, but have to keep + // the stores to "local1". + const std::string test = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" +OpSourceExtension "GL_GOOGLE_include_directive" +OpName %main "main" +OpName %a "a" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Uniform_int = OpTypePointer Uniform %int +%int_0 = OpConstant %int 0 +%bool = OpTypeBool +%int_100 = OpConstant %int 100 +%int_1 = OpConstant %int 1 +%a = OpVariable %_ptr_Uniform_int Uniform +%main = OpFunction %void None %5 +%12 = OpLabel +%13 = OpLoad %int %a +%14 = OpSGreaterThan %bool %13 %int_0 +OpSelectionMerge %15 None +OpBranchConditional %14 %16 %15 +%16 = OpLabel +%17 = OpLoad %int %a +%18 = OpSLessThan %bool %17 %int_100 +OpBranchConditional %18 %19 %15 +%19 = OpLabel +OpStore %a %int_1 +OpBranch %15 +%15 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Check that logical addressing required // Check that function calls inhibit optimization // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/assembly_builder.h b/3rdparty/spirv-tools/test/opt/assembly_builder.h index 2edd4a58c..1673c092b 100644 --- a/3rdparty/spirv-tools/test/opt/assembly_builder.h +++ b/3rdparty/spirv-tools/test/opt/assembly_builder.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_OPT_ASSEMBLY_BUILDER -#define LIBSPIRV_TEST_OPT_ASSEMBLY_BUILDER +#ifndef TEST_OPT_ASSEMBLY_BUILDER_H_ +#define TEST_OPT_ASSEMBLY_BUILDER_H_ #include #include @@ -23,6 +23,7 @@ #include namespace spvtools { +namespace opt { // A simple SPIR-V assembly code builder for test uses. It builds an SPIR-V // assembly module from vectors of assembly strings. It allows users to add @@ -259,6 +260,7 @@ class AssemblyBuilder { std::unordered_set used_names_; }; +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_TEST_OPT_ASSEMBLY_BUILDER +#endif // TEST_OPT_ASSEMBLY_BUILDER_H_ diff --git a/3rdparty/spirv-tools/test/opt/assembly_builder_test.cpp b/3rdparty/spirv-tools/test/opt/assembly_builder_test.cpp index 12c796392..55fbbe904 100644 --- a/3rdparty/spirv-tools/test/opt/assembly_builder_test.cpp +++ b/3rdparty/spirv-tools/test/opt/assembly_builder_test.cpp @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" +#include "test/opt/assembly_builder.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using AssemblyBuilderTest = PassTest<::testing::Test>; TEST_F(AssemblyBuilderTest, MinimalShader) { @@ -44,9 +45,8 @@ TEST_F(AssemblyBuilderTest, MinimalShader) { // clang-format on }; - SinglePassRunAndCheck(builder.GetCode(), - JoinAllInsts(expected), - /* skip_nop = */ false); + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); } TEST_F(AssemblyBuilderTest, ShaderWithConstants) { @@ -158,9 +158,8 @@ TEST_F(AssemblyBuilderTest, ShaderWithConstants) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck(builder.GetCode(), - JoinAllInsts(expected), - /* skip_nop = */ false); + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); } TEST_F(AssemblyBuilderTest, SpecConstants) { @@ -242,9 +241,8 @@ TEST_F(AssemblyBuilderTest, SpecConstants) { // clang-format on }; - SinglePassRunAndCheck(builder.GetCode(), - JoinAllInsts(expected), - /* skip_nop = */ false); + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); } TEST_F(AssemblyBuilderTest, AppendNames) { @@ -276,9 +274,10 @@ TEST_F(AssemblyBuilderTest, AppendNames) { // clang-format on }; - SinglePassRunAndCheck(builder.GetCode(), - JoinAllInsts(expected), - /* skip_nop = */ false); + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/block_merge_test.cpp b/3rdparty/spirv-tools/test/opt/block_merge_test.cpp index 11399e6ce..aaa70cd4a 100644 --- a/3rdparty/spirv-tools/test/opt/block_merge_test.cpp +++ b/3rdparty/spirv-tools/test/opt/block_merge_test.cpp @@ -13,13 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using BlockMergeTest = PassTest<::testing::Test>; TEST_F(BlockMergeTest, Simple) { @@ -84,8 +86,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, predefs + after, - true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, true, + true); } TEST_F(BlockMergeTest, EmptyBlock) { @@ -154,8 +156,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, predefs + after, - true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, true, + true); } TEST_F(BlockMergeTest, NestedInControlFlow) { @@ -267,8 +269,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, predefs + after, - true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, true, + true); } #ifdef SPIRV_EFFCEE @@ -306,7 +308,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, UpdateMergeInstruction) { @@ -342,7 +344,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, TwoMergeBlocksCannotBeMerged) { @@ -383,7 +385,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, MergeContinue) { @@ -415,7 +417,7 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, TwoHeadersCannotBeMerged) { @@ -452,7 +454,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, RemoveStructuredDeclaration) { @@ -516,7 +518,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(assembly, true); + SinglePassRunAndMatch(assembly, true); } TEST_F(BlockMergeTest, DontMergeKill) { @@ -548,7 +550,7 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, DontMergeUnreachable) { @@ -580,7 +582,7 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, DontMergeReturn) { @@ -612,7 +614,7 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, DontMergeSwitch) { @@ -648,7 +650,7 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(BlockMergeTest, DontMergeReturnValue) { @@ -687,7 +689,7 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE @@ -696,4 +698,6 @@ OpFunctionEnd // More complex control flow // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/ccp_test.cpp b/3rdparty/spirv-tools/test/opt/ccp_test.cpp index 47ae9cc00..5ccea71fb 100644 --- a/3rdparty/spirv-tools/test/opt/ccp_test.cpp +++ b/3rdparty/spirv-tools/test/opt/ccp_test.cpp @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "pass_fixture.h" -#include "pass_utils.h" - -#include "opt/ccp_pass.h" +#include "source/opt/ccp_pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using CCPTest = PassTest<::testing::Test>; // TODO(dneto): Add Effcee as required dependency, and make this unconditional. @@ -82,7 +83,7 @@ TEST_F(CCPTest, PropagateThroughPhis) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, SimplifyConditionals) { @@ -139,7 +140,7 @@ TEST_F(CCPTest, SimplifyConditionals) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, SimplifySwitches) { @@ -188,7 +189,7 @@ TEST_F(CCPTest, SimplifySwitches) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, SimplifySwitchesDefaultBranch) { @@ -237,7 +238,7 @@ TEST_F(CCPTest, SimplifySwitchesDefaultBranch) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, SimplifyIntVector) { @@ -288,7 +289,7 @@ TEST_F(CCPTest, SimplifyIntVector) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, BadSimplifyFloatVector) { @@ -341,7 +342,7 @@ TEST_F(CCPTest, BadSimplifyFloatVector) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, NoLoadStorePropagation) { @@ -383,7 +384,7 @@ TEST_F(CCPTest, NoLoadStorePropagation) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, HandleAbortInstructions) { @@ -416,7 +417,7 @@ TEST_F(CCPTest, HandleAbortInstructions) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, SSAWebCycles) { @@ -467,7 +468,7 @@ TEST_F(CCPTest, SSAWebCycles) { )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, LoopInductionVariables) { @@ -521,7 +522,7 @@ TEST_F(CCPTest, LoopInductionVariables) { OpFunctionEnd )"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); } TEST_F(CCPTest, HandleCompositeWithUndef) { @@ -552,8 +553,8 @@ TEST_F(CCPTest, HandleCompositeWithUndef) { OpFunctionEnd )"; - auto res = SinglePassRunToBinary(spv_asm, true); - EXPECT_EQ(std::get<1>(res), opt::Pass::Status::SuccessWithoutChange); + auto res = SinglePassRunToBinary(spv_asm, true); + EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange); } TEST_F(CCPTest, SkipSpecConstantInstrucitons) { @@ -579,8 +580,8 @@ TEST_F(CCPTest, SkipSpecConstantInstrucitons) { OpFunctionEnd )"; - auto res = SinglePassRunToBinary(spv_asm, true); - EXPECT_EQ(std::get<1>(res), opt::Pass::Status::SuccessWithoutChange); + auto res = SinglePassRunToBinary(spv_asm, true); + EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange); } TEST_F(CCPTest, UpdateSubsequentPhisToVarying) { @@ -639,8 +640,8 @@ OpReturn OpFunctionEnd )"; - auto res = SinglePassRunToBinary(text, true); - EXPECT_EQ(std::get<1>(res), opt::Pass::Status::SuccessWithoutChange); + auto res = SinglePassRunToBinary(text, true); + EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange); } TEST_F(CCPTest, UndefInPhi) { @@ -678,7 +679,7 @@ TEST_F(CCPTest, UndefInPhi) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } // Just test to make sure the constant fold rules are being used. Will rely on @@ -704,7 +705,7 @@ TEST_F(CCPTest, UseConstantFoldingRules) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } // Test for #1300. Previously value for %5 would not settle during simulation. @@ -731,7 +732,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunToBinary(text, true); + SinglePassRunToBinary(text, true); } TEST_F(CCPTest, NullBranchCondition) { @@ -762,7 +763,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(CCPTest, UndefBranchCondition) { @@ -793,7 +794,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(CCPTest, NullSwitchCondition) { @@ -823,7 +824,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(CCPTest, UndefSwitchCondition) { @@ -853,7 +854,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } // Test for #1361. @@ -888,8 +889,10 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/cfg_cleanup_test.cpp b/3rdparty/spirv-tools/test/opt/cfg_cleanup_test.cpp index 0f6082e17..369c76670 100644 --- a/3rdparty/spirv-tools/test/opt/cfg_cleanup_test.cpp +++ b/3rdparty/spirv-tools/test/opt/cfg_cleanup_test.cpp @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using CFGCleanupTest = PassTest<::testing::Test>; TEST_F(CFGCleanupTest, RemoveUnreachableBlocks) { @@ -79,8 +81,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - declarations + body_before, declarations + body_after, true, true); + SinglePassRunAndCheck(declarations + body_before, + declarations + body_after, true, true); } TEST_F(CFGCleanupTest, RemoveDecorations) { @@ -142,7 +144,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } TEST_F(CFGCleanupTest, UpdatePhis) { @@ -226,7 +228,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } TEST_F(CFGCleanupTest, RemoveNamedLabels) { @@ -261,7 +263,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } TEST_F(CFGCleanupTest, RemovePhiArgsFromFarBlocks) { @@ -359,7 +361,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } TEST_F(CFGCleanupTest, RemovePhiConstantArgs) { @@ -438,6 +440,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/combine_access_chains_test.cpp b/3rdparty/spirv-tools/test/opt/combine_access_chains_test.cpp new file mode 100644 index 000000000..ab9e185b6 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/combine_access_chains_test.cpp @@ -0,0 +1,754 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CombineAccessChainsTest = PassTest<::testing::Test>; + +#ifdef SPIRV_EFFCEE +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsAccessChainConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsAccessChain %ptr_Workgroup_uint %var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainCombineConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int2]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_1 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainNonConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld1:%\w+]] = OpLoad +; CHECK: [[ld2:%\w+]] = OpLoad +; CHECK: [[add:%\w+]] = OpIAdd [[int]] [[ld1]] [[ld2]] +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[add]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%local_var = OpVariable %ptr_Function_uint Function +%ld1 = OpLoad %uint %local_var +%gep = OpAccessChain %ptr_Workgroup_uint %var %ld1 +%ld2 = OpLoad %uint %local_var +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainExtraIndices) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%uint_array_4_array_4_array_4 = OpTypeArray %uint_array_4_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%ptr_Workgroup_uint_array_4_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_2 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessChainCombineElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int6]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessChainOnlyElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessCombineNonElementIndex) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + AccessChainFromPtrAccessChainOnlyElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, AccessChainFromPtrAccessChainAppend) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_2 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, AccessChainFromAccessChainAppend) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NonConstantStructSlide) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[ld]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%struct = OpTypeStruct %uint %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%wg_var = OpVariable %ptr_Workgroup_struct Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, DontCombineNonConstantStructSlide) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: [[gep:%\w+]] = OpAccessChain +; CHECK: OpPtrAccessChain {{%\w+}} [[gep]] [[ld]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%struct = OpTypeStruct %uint %uint +%struct_array_4 = OpTypeArray %struct %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%ptr_Workgroup_struct_array_4 = OpTypePointer Workgroup %struct_array_4 +%wg_var = OpVariable %ptr_Workgroup_struct_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%gep = OpAccessChain %ptr_Workgroup_struct %wg_var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, CombineNonConstantStructSlideElement) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: [[add:%\w+]] = OpIAdd {{%\w+}} [[ld]] [[ld]] +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[add]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%struct = OpTypeStruct %uint %uint +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%wg_var = OpVariable %ptr_Workgroup_struct Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, InBoundsPtrAccessChainFromPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + InBoundsPtrAccessChainFromInBoundsPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpInBoundsPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexAccessChains) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NOT: OpConstant +; CHECK: [[gep:%\w+]] = OpAccessChain {{%\w+}} [[var]] +; CHECK: OpAccessChain {{%\w+}} [[var]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpAccessChain %ptr_Workgroup_uint %var +%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[gep:%\w+]] = OpPtrAccessChain {{%\w+}} [[var]] [[int0]] +; CHECK: OpCopyObject {{%\w+}} [[gep]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpPtrAccessChain %ptr_Workgroup_uint %var %uint_0 +%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains2) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpAccessChain %ptr_Workgroup_uint %var +%gep2 = OpPtrAccessChain %ptr_Workgroup_uint %gep1 %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, CombineMixedSign) { + const std::string text = R"( +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2 +; CHECK: OpInBoundsPtrAccessChain {{%\w+}} [[var]] [[uint2]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%int = OpTypeInt 32 1 +%uint_1 = OpConstant %uint 1 +%int_1 = OpConstant %int 1 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %var %uint_1 +%gep2 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %gep1 %int_1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} +#endif // SPIRV_EFFCEE + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/common_uniform_elim_test.cpp b/3rdparty/spirv-tools/test/opt/common_uniform_elim_test.cpp index cf9cc1bb8..f5199ed87 100644 --- a/3rdparty/spirv-tools/test/opt/common_uniform_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/common_uniform_elim_test.cpp @@ -13,12 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" +#include +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using CommonUniformElimTest = PassTest<::testing::Test>; TEST_F(CommonUniformElimTest, Basic1) { @@ -165,8 +167,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(CommonUniformElimTest, Basic2) { @@ -329,8 +331,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(CommonUniformElimTest, Basic3) { @@ -448,8 +450,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(CommonUniformElimTest, Loop) { @@ -659,8 +661,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(CommonUniformElimTest, Volatile1) { @@ -809,8 +811,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(CommonUniformElimTest, Volatile2) { @@ -918,10 +920,9 @@ OpReturn OpFunctionEnd )"; - opt::Pass::Status res = - std::get<1>(SinglePassRunAndDisassemble( - text, true, false)); - EXPECT_EQ(res, opt::Pass::Status::SuccessWithoutChange); + Pass::Status res = std::get<1>( + SinglePassRunAndDisassemble(text, true, false)); + EXPECT_EQ(res, Pass::Status::SuccessWithoutChange); } TEST_F(CommonUniformElimTest, Volatile3) { @@ -1036,10 +1037,186 @@ OpReturn OpFunctionEnd )"; - opt::Pass::Status res = - std::get<1>(SinglePassRunAndDisassemble( - text, true, false)); - EXPECT_EQ(res, opt::Pass::Status::SuccessWithoutChange); + Pass::Status res = std::get<1>( + SinglePassRunAndDisassemble(text, true, false)); + EXPECT_EQ(res, Pass::Status::SuccessWithoutChange); +} + +TEST_F(CommonUniformElimTest, IteratorDanglingPointer) { + // Note: This test exemplifies the following: + // - Existing common uniform (%_) load kept in place and shared + // + // #version 140 + // in vec4 BaseColor; + // in float fi; + // + // layout(std140) uniform U_t + // { + // bool g_B; + // float g_F; + // } ; + // + // uniform float alpha; + // uniform bool alpha_B; + // + // void main() + // { + // vec4 v = BaseColor; + // if (g_B) { + // v = v * g_F; + // if (alpha_B) + // v = v * alpha; + // else + // v = v * fi; + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor %fi +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_B" +OpMemberName %U_t 1 "g_F" +OpName %alpha "alpha" +OpName %alpha_B "alpha_B" +OpName %_ "" +OpName %gl_FragColor "gl_FragColor" +OpName %fi "fi" +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%U_t = OpTypeStruct %uint %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%int_1 = OpConstant %int 1 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%alpha = OpVariable %_ptr_Uniform_float Uniform +%alpha_B = OpVariable %_ptr_Uniform_uint Uniform +)"; + + const std::string before = + R"(%main = OpFunction %void None %12 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%29 = OpLoad %uint %28 +%30 = OpINotEqual %bool %29 %uint_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %31 %32 +%32 = OpLabel +%47 = OpLoad %v4float %v +OpStore %gl_FragColor %47 +OpReturn +%31 = OpLabel +%33 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%34 = OpLoad %float %33 +%35 = OpLoad %v4float %v +%36 = OpVectorTimesScalar %v4float %35 %34 +OpStore %v %36 +%37 = OpLoad %uint %alpha_B +%38 = OpIEqual %bool %37 %uint_0 +OpSelectionMerge %43 None +OpBranchConditional %38 %43 %39 +%39 = OpLabel +%40 = OpLoad %float %alpha +%41 = OpLoad %v4float %v +%42 = OpVectorTimesScalar %v4float %41 %40 +OpStore %v %42 +OpBranch %50 +%50 = OpLabel +%51 = OpLoad %v4float %v +OpStore %gl_FragColor %51 +OpReturn +%43 = OpLabel +%44 = OpLoad %float %fi +%45 = OpLoad %v4float %v +%46 = OpVectorTimesScalar %v4float %45 %44 +OpStore %v %46 +OpBranch %60 +%60 = OpLabel +%61 = OpLoad %v4float %v +OpStore %gl_FragColor %61 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %12 +%28 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%29 = OpLoad %v4float %BaseColor +OpStore %v %29 +%54 = OpLoad %U_t %_ +%55 = OpCompositeExtract %uint %54 0 +%32 = OpINotEqual %bool %55 %uint_0 +OpSelectionMerge %33 None +OpBranchConditional %32 %33 %34 +%34 = OpLabel +%35 = OpLoad %v4float %v +OpStore %gl_FragColor %35 +OpReturn +%33 = OpLabel +%58 = OpLoad %float %alpha +%57 = OpCompositeExtract %float %54 1 +%38 = OpLoad %v4float %v +%39 = OpVectorTimesScalar %v4float %38 %57 +OpStore %v %39 +%40 = OpLoad %uint %alpha_B +%41 = OpIEqual %bool %40 %uint_0 +OpSelectionMerge %42 None +OpBranchConditional %41 %42 %43 +%43 = OpLabel +%45 = OpLoad %v4float %v +%46 = OpVectorTimesScalar %v4float %45 %58 +OpStore %v %46 +OpBranch %47 +%47 = OpLabel +%48 = OpLoad %v4float %v +OpStore %gl_FragColor %48 +OpReturn +%42 = OpLabel +%49 = OpLoad %float %fi +%50 = OpLoad %v4float %v +%51 = OpVectorTimesScalar %v4float %50 %49 +OpStore %v %51 +OpBranch %52 +%52 = OpLabel +%53 = OpLoad %v4float %v +OpStore %gl_FragColor %53 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } #ifdef SPIRV_EFFCEE @@ -1150,7 +1327,7 @@ TEST_F(CommonUniformElimTest, MixedConstantAndNonConstantIndexes) { )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE // TODO(greg-lunarg): Add tests to verify handling of these cases: @@ -1159,4 +1336,6 @@ TEST_F(CommonUniformElimTest, MixedConstantAndNonConstantIndexes) { // non-structured control flow // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/compact_ids_test.cpp b/3rdparty/spirv-tools/test/opt/compact_ids_test.cpp index d5fcfb95c..b1e4b2cbb 100644 --- a/3rdparty/spirv-tools/test/opt/compact_ids_test.cpp +++ b/3rdparty/spirv-tools/test/opt/compact_ids_test.cpp @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include +#include +#include "gmock/gmock.h" #include "spirv-tools/libspirv.hpp" #include "spirv-tools/optimizer.hpp" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" -#include "pass_fixture.h" -#include "pass_utils.h" - +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using CompactIdsTest = PassTest<::testing::Test>; TEST_F(CompactIdsTest, PassOff) { @@ -43,7 +45,7 @@ OpMemoryModel Physical32 OpenCL SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, false); + SinglePassRunAndCheck(before, after, false, false); } TEST_F(CompactIdsTest, PassOn) { @@ -87,7 +89,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, false); + SinglePassRunAndCheck(before, after, false, false); } TEST(CompactIds, InstructionResultIsUpdated) { @@ -105,7 +107,7 @@ OpMemoryModel Logical Simple OpEntryPoint GLCompute %100 "main" %200 = OpTypeVoid %300 = OpTypeFunction %200 -%100 = OpFunction %300 None %200 +%100 = OpFunction %200 None %300 %400 = OpLabel OpReturn OpFunctionEnd @@ -134,7 +136,7 @@ OpMemoryModel Logical Simple OpEntryPoint GLCompute %1 "main" %2 = OpTypeVoid %3 = OpTypeFunction %2 -%1 = OpFunction %3 None %2 +%1 = OpFunction %2 None %3 %4 = OpLabel OpReturn OpFunctionEnd @@ -149,7 +151,7 @@ OpMemoryModel Logical Simple OpEntryPoint GLCompute %100 "main" %200 = OpTypeVoid %300 = OpTypeFunction %200 -%100 = OpFunction %300 None %200 +%100 = OpFunction %200 None %300 %400 = OpLabel OpReturn OpFunctionEnd @@ -183,7 +185,7 @@ OpMemoryModel Logical Simple OpEntryPoint GLCompute %1 "main" %2 = OpTypeVoid %3 = OpTypeFunction %2 -%1 = OpFunction %3 None %2 +%1 = OpFunction %2 None %3 %4 = OpLabel OpReturn OpFunctionEnd @@ -192,4 +194,86 @@ OpFunctionEnd EXPECT_THAT(disassembly, ::testing::Eq(expected)); } -} // anonymous namespace +// Test context consistency check after invalidating +// CFG and others by compact IDs Pass. +// Uses a GLSL shader with named labels for variety +TEST(CompactIds, ConsistentCheck) { + const std::string input(R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_A %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %main "main" +OpName %in_var_A "in.var.A" +OpName %out_var_SV_TARGET "out.var.SV_TARGET" +OpDecorate %in_var_A Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%in_var_A = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %3 +%5 = OpLabel +%12 = OpLoad %v4float %in_var_A +%23 = OpVectorShuffle %v4float %12 %12 0 0 0 1 +OpStore %out_var_SV_TARGET %23 +OpReturn +OpFunctionEnd +)"); + + spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_1); + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + CompactIdsPass compact_id_pass; + context->BuildInvalidAnalyses(compact_id_pass.GetPreservedAnalyses()); + const auto status = compact_id_pass.Run(context.get()); + ASSERT_NE(status, Pass::Status::Failure); + EXPECT_TRUE(context->IsConsistent()); + + // Test output just in case + std::vector binary; + context->module()->ToBinary(&binary, false); + std::string disassembly; + tools.Disassemble(binary, &disassembly, + SpirvTools::kDefaultDisassembleOption); + + const std::string expected(R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_A %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %main "main" +OpName %in_var_A "in.var.A" +OpName %out_var_SV_TARGET "out.var.SV_TARGET" +OpDecorate %in_var_A Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%in_var_A = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %5 +%10 = OpLabel +%11 = OpLoad %v4float %in_var_A +%12 = OpVectorShuffle %v4float %11 %11 0 0 0 1 +OpStore %out_var_SV_TARGET %12 +OpReturn +OpFunctionEnd +)"); + + EXPECT_THAT(disassembly, ::testing::Eq(expected)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/constant_manager_test.cpp b/3rdparty/spirv-tools/test/opt/constant_manager_test.cpp new file mode 100644 index 000000000..57dea6512 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/constant_manager_test.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/constants.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +using ConstantManagerTest = ::testing::Test; + +TEST_F(ConstantManagerTest, GetDefiningInstruction) { + const std::string text = R"( +%int = OpTypeInt 32 0 +%1 = OpTypeStruct %int +%2 = OpTypeStruct %int + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + Type* struct_type_1 = context->get_type_mgr()->GetType(1); + StructConstant struct_const_1(struct_type_1->AsStruct()); + Instruction* const_inst_1 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_1, 1); + EXPECT_EQ(const_inst_1->type_id(), 1); + + Type* struct_type_2 = context->get_type_mgr()->GetType(2); + StructConstant struct_const_2(struct_type_2->AsStruct()); + Instruction* const_inst_2 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_2, 2); + EXPECT_EQ(const_inst_2->type_id(), 2); +} + +TEST_F(ConstantManagerTest, GetDefiningInstruction2) { + const std::string text = R"( +%int = OpTypeInt 32 0 +%1 = OpTypeStruct %int +%2 = OpTypeStruct %int +%3 = OpConstantNull %1 +%4 = OpConstantNull %2 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + Type* struct_type_1 = context->get_type_mgr()->GetType(1); + NullConstant struct_const_1(struct_type_1->AsStruct()); + Instruction* const_inst_1 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_1, 1); + EXPECT_EQ(const_inst_1->type_id(), 1); + EXPECT_EQ(const_inst_1->result_id(), 3); + + Type* struct_type_2 = context->get_type_mgr()->GetType(2); + NullConstant struct_const_2(struct_type_2->AsStruct()); + Instruction* const_inst_2 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_2, 2); + EXPECT_EQ(const_inst_2->type_id(), 2); + EXPECT_EQ(const_inst_2->result_id(), 4); +} + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/copy_prop_array_test.cpp b/3rdparty/spirv-tools/test/opt/copy_prop_array_test.cpp index a58acdfd7..dcce77d22 100644 --- a/3rdparty/spirv-tools/test/opt/copy_prop_array_test.cpp +++ b/3rdparty/spirv-tools/test/opt/copy_prop_array_test.cpp @@ -13,19 +13,16 @@ // limitations under the License. #include +#include -#include - -#include "assembly_builder.h" -#include "pass_fixture.h" +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; -using ir::Instruction; -using ir::IRContext; -using opt::PassManager; - using CopyPropArrayPassTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -105,7 +102,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - SinglePassRunAndMatch(before, false); + SinglePassRunAndMatch(before, false); } TEST_F(CopyPropArrayPassTest, BasicPropagateArrayWithName) { @@ -185,10 +182,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - SinglePassRunAndMatch(before, false); + SinglePassRunAndMatch(before, false); } -// Propagate 2d array. This test identifing a copy through multiple levels. +// Propagate 2d array. This test identifying a copy through multiple levels. // Also has to traverse multiple OpAccessChains. TEST_F(CopyPropArrayPassTest, Propagate2DArray) { const std::string text = @@ -274,7 +271,94 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); +} + +// Propagate 2d array. This test identifying a copy through multiple levels. +// Also has to traverse multiple OpAccessChains. +TEST_F(CopyPropArrayPassTest, Propagate2DArrayWithMultiLevelExtract) { + const std::string text = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_2 ArrayStride 16 +OpDecorate %_arr__arr_v4float_uint_2_uint_2 ArrayStride 32 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_2 = OpConstant %uint 2 +%_arr_v4float_uint_2 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_uint_2 = OpTypeArray %_arr_v4float_uint_2 %uint_2 +%type_MyCBuffer = OpTypeStruct %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%14 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_2_0 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_0_uint_2 = OpTypeArray %_arr_v4float_uint_2_0 %uint_2 +%_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 = OpTypePointer Function %_arr__arr_v4float_uint_2_0_uint_2 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 = OpTypePointer Uniform %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Function__arr_v4float_uint_2_0 = OpTypePointer Function %_arr_v4float_uint_2_0 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: OpVariable +; CHECK: OpVariable +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 +%main = OpFunction %void None %14 +%25 = OpLabel +%26 = OpVariable %_ptr_Function__arr_v4float_uint_2_0 Function +%27 = OpVariable %_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 Function +%28 = OpLoad %int %in_var_INDEX +%29 = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 +%30 = OpLoad %_arr__arr_v4float_uint_2_uint_2 %29 +%32 = OpCompositeExtract %v4float %30 0 0 +%33 = OpCompositeExtract %v4float %30 0 1 +%34 = OpCompositeConstruct %_arr_v4float_uint_2_0 %32 %33 +%36 = OpCompositeExtract %v4float %30 1 0 +%37 = OpCompositeExtract %v4float %30 1 1 +%38 = OpCompositeConstruct %_arr_v4float_uint_2_0 %36 %37 +%39 = OpCompositeConstruct %_arr__arr_v4float_uint_2_0_uint_2 %34 %38 +; CHECK: OpStore +OpStore %27 %39 +%40 = OpAccessChain %_ptr_Function__arr_v4float_uint_2_0 %27 %28 +%42 = OpAccessChain %_ptr_Function_v4float %40 %28 +%43 = OpLoad %v4float %42 +; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_2 [[new_address]] %28 +; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[ac1]] %28 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[ac2]] +; CHECK: OpStore %out_var_SV_Target [[load]] +OpStore %out_var_SV_Target %43 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(text, false); } // Test decomposing an object when we need to "rewrite" a store. @@ -359,7 +443,7 @@ TEST_F(CopyPropArrayPassTest, DecomposeObjectForArrayStore) { SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Test decomposing an object when we need to "rewrite" a store. @@ -447,7 +531,7 @@ TEST_F(CopyPropArrayPassTest, DecomposeObjectForStructStore) { SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(CopyPropArrayPassTest, CopyViaInserts) { @@ -534,7 +618,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - SinglePassRunAndMatch(before, false); + SinglePassRunAndMatch(before, false); } #endif // SPIRV_EFFCEE @@ -610,10 +694,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } // This test will place a load where it is not dominated by the store. We @@ -698,10 +782,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } // This test has a partial store to the variable. We cannot propagate in this @@ -777,10 +861,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } // This test does not have a proper copy of an object. We cannot propagate in @@ -855,10 +939,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(CopyPropArrayPassTest, BadCopyViaInserts1) { @@ -937,10 +1021,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(CopyPropArrayPassTest, BadCopyViaInserts2) { @@ -1019,10 +1103,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(CopyPropArrayPassTest, BadCopyViaInserts3) { @@ -1099,10 +1183,10 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(CopyPropArrayPassTest, AtomicAdd) { @@ -1179,6 +1263,9 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, true); + SinglePassRunAndCheck(before, after, true, true); } + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dead_branch_elim_test.cpp b/3rdparty/spirv-tools/test/opt/dead_branch_elim_test.cpp index 3ad5fc315..29084e3bb 100644 --- a/3rdparty/spirv-tools/test/opt/dead_branch_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/dead_branch_elim_test.cpp @@ -13,13 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using DeadBranchElimTest = PassTest<::testing::Test>; TEST_F(DeadBranchElimTest, IfThenElseTrue) { @@ -99,8 +101,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, IfThenElseFalse) { @@ -180,8 +182,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, IfThenTrue) { @@ -228,13 +230,13 @@ OpName %gl_FragColor "gl_FragColor" %17 = OpLabel %v = OpVariable %_ptr_Function_v4float Function %18 = OpLoad %v4float %BaseColor -OpStore %v %18 -OpSelectionMerge %19 None +OpStore %v %18 +OpSelectionMerge %19 None OpBranchConditional %true %20 %19 %20 = OpLabel %21 = OpLoad %v4float %v %22 = OpFMul %v4float %21 %15 -OpStore %v %22 +OpStore %v %22 OpBranch %19 %19 = OpLabel %23 = OpLoad %v4float %v @@ -262,8 +264,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, IfThenFalse) { @@ -310,13 +312,13 @@ OpName %gl_FragColor "gl_FragColor" %17 = OpLabel %v = OpVariable %_ptr_Function_v4float Function %18 = OpLoad %v4float %BaseColor -OpStore %v %18 -OpSelectionMerge %19 None +OpStore %v %18 +OpSelectionMerge %19 None OpBranchConditional %false %20 %19 %20 = OpLabel %21 = OpLoad %v4float %v %22 = OpFMul %v4float %21 %15 -OpStore %v %22 +OpStore %v %22 OpBranch %19 %19 = OpLabel %23 = OpLoad %v4float %v @@ -339,8 +341,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, IfThenElsePhiTrue) { @@ -412,8 +414,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, IfThenElsePhiFalse) { @@ -485,8 +487,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, CompoundIfThenElseFalse) { @@ -558,7 +560,7 @@ OpDecorate %_ DescriptorSet 0 %25 = OpLabel %v = OpVariable %_ptr_Function_v4float Function OpSelectionMerge %26 None -OpBranchConditional %false %27 %28 +OpBranchConditional %false %27 %28 %27 = OpLabel %29 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 %30 = OpLoad %uint %29 @@ -605,12 +607,12 @@ OpBranch %28 %37 = OpINotEqual %bool %36 %uint_0 OpSelectionMerge %38 None OpBranchConditional %37 %39 %40 -%39 = OpLabel -OpStore %v %23 -OpBranch %38 %40 = OpLabel OpStore %v %21 OpBranch %38 +%39 = OpLabel +OpStore %v %23 +OpBranch %38 %38 = OpLabel OpBranch %26 %26 = OpLabel @@ -620,8 +622,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, PreventOrphanMerge) { @@ -684,8 +686,8 @@ OpKill OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, HandleOrphanMerge) { @@ -731,7 +733,7 @@ OpReturnValue %13 %22 = OpLabel OpReturnValue %15 %20 = OpLabel -%23 = OpUndef %v4float +%23 = OpUndef %v4float OpReturnValue %23 OpFunctionEnd )"; @@ -745,8 +747,8 @@ OpReturnValue %13 OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, KeepContinueTargetWhenKillAfterMerge) { @@ -802,7 +804,7 @@ OpBranchConditional %17 %19 %18 OpBranch %13 %18 = OpLabel OpSelectionMerge %20 None -OpBranchConditional %false %21 %20 +OpBranchConditional %false %21 %20 %21 = OpLabel OpBranch %13 %20 = OpLabel @@ -843,8 +845,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, DecorateDeleted) { @@ -919,13 +921,13 @@ OpName %gl_FragColor "gl_FragColor" %17 = OpLabel %v = OpVariable %_ptr_Function_v4float Function %18 = OpLoad %v4float %BaseColor -OpStore %v %18 -OpSelectionMerge %19 None +OpStore %v %18 +OpSelectionMerge %19 None OpBranchConditional %false %20 %19 %20 = OpLabel %21 = OpLoad %v4float %v %22 = OpFMul %v4float %21 %15 -OpStore %v %22 +OpStore %v %22 OpBranch %19 %19 = OpLabel %23 = OpLoad %v4float %v @@ -948,8 +950,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs_before + before, predefs_after + after, true, true); + SinglePassRunAndCheck(predefs_before + before, + predefs_after + after, true, true); } TEST_F(DeadBranchElimTest, LoopInDeadBranch) { @@ -1008,7 +1010,7 @@ OpDecorate %OutColor Location 0 %23 = OpLoad %v4float %BaseColor OpStore %v %23 OpSelectionMerge %24 None -OpBranchConditional %false %25 %24 +OpBranchConditional %false %25 %24 %25 = OpLabel OpStore %i %int_0 OpBranch %26 @@ -1053,8 +1055,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, SwitchLiveCase) { @@ -1115,20 +1117,20 @@ OpDecorate %BaseColor Location 0 const std::string before = R"(%main = OpFunction %void None %6 %21 = OpLabel -OpSelectionMerge %22 None +OpSelectionMerge %22 None OpSwitch %int_1 %23 0 %24 1 %25 2 %26 %23 = OpLabel OpStore %OutColor %19 -OpBranch %22 +OpBranch %22 %24 = OpLabel OpStore %OutColor %13 -OpBranch %22 +OpBranch %22 %25 = OpLabel OpStore %OutColor %15 -OpBranch %22 +OpBranch %22 %26 = OpLabel OpStore %OutColor %17 -OpBranch %22 +OpBranch %22 %22 = OpLabel OpReturn OpFunctionEnd @@ -1146,8 +1148,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, SwitchLiveDefault) { @@ -1208,20 +1210,20 @@ OpDecorate %BaseColor Location 0 const std::string before = R"(%main = OpFunction %void None %6 %21 = OpLabel -OpSelectionMerge %22 None +OpSelectionMerge %22 None OpSwitch %int_7 %23 0 %24 1 %25 2 %26 %23 = OpLabel OpStore %OutColor %19 -OpBranch %22 +OpBranch %22 %24 = OpLabel OpStore %OutColor %13 -OpBranch %22 +OpBranch %22 %25 = OpLabel OpStore %OutColor %15 -OpBranch %22 +OpBranch %22 %26 = OpLabel OpStore %OutColor %17 -OpBranch %22 +OpBranch %22 %22 = OpLabel OpReturn OpFunctionEnd @@ -1239,8 +1241,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(DeadBranchElimTest, SwitchLiveCaseBreakFromLoop) { @@ -1304,13 +1306,13 @@ OpStore %oc %17 OpBranch %28 %33 = OpLabel OpStore %oc %19 -OpBranch %28 +OpBranch %28 %34 = OpLabel OpStore %oc %21 -OpBranch %28 +OpBranch %28 %31 = OpLabel OpStore %oc %23 -OpBranch %28 +OpBranch %28 %29 = OpLabel OpBranchConditional %false %27 %28 %28 = OpLabel @@ -1342,8 +1344,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(predefs + before, - predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } #ifdef SPIRV_EFFCEE @@ -1375,7 +1377,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, LeaveContinueBackedgeExtraBlock) { const std::string text = R"( @@ -1412,7 +1414,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, RemovePhiWithUnreachableContinue) { @@ -1453,7 +1455,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, UnreachableLoopMergeAndContinueTargets) { @@ -1502,7 +1504,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, EarlyReconvergence) { const std::string text = R"( @@ -1548,7 +1550,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksFloating) { @@ -1572,7 +1574,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksFloatingJoin) { @@ -1609,7 +1611,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksDeadPhi) { @@ -1645,7 +1647,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksPartiallyDeadPhi) { @@ -1654,12 +1656,12 @@ TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksPartiallyDeadPhi) { ; CHECK-NEXT: [[param:%\w+]] = OpFunctionParameter ; CHECK-NEXT: OpLabel ; CHECK-NEXT: OpBranchConditional [[param]] [[merge:%\w+]] [[br:%\w+]] -; CHECK-NEXT: [[br]] = OpLabel -; CHECK-NEXT: OpBranch [[merge]] ; CHECK-NEXT: [[merge]] = OpLabel ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool %true %2 %false [[br]] ; CHECK-NEXT: OpLogicalNot %bool [[phi]] ; CHECK-NEXT: OpReturn +; CHECK-NEXT: [[br]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] ; CHECK-NEXT: OpFunctionEnd OpCapability Kernel OpCapability Linkage @@ -1687,7 +1689,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, LiveHeaderDeadPhi) { @@ -1719,7 +1721,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, ExtraBackedgeBlocksLive) { @@ -1762,7 +1764,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, ExtraBackedgeBlocksUnreachable) { @@ -1772,9 +1774,10 @@ TEST_F(DeadBranchElimTest, ExtraBackedgeBlocksUnreachable) { ; CHECK-NEXT: [[header]] = OpLabel ; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None ; CHECK-NEXT: OpBranch [[merge]] +; CHECK-NEXT: [[merge]] = OpLabel +; CHECK-NEXT: OpReturn ; CHECK-NEXT: [[continue]] = OpLabel ; CHECK-NEXT: OpBranch [[header]] -; CHECK-NEXT: [[merge]] = OpLabel OpCapability Kernel OpCapability Linkage OpMemoryModel Logical OpenCL @@ -1804,7 +1807,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, NoUnnecessaryChanges) { @@ -1832,8 +1835,8 @@ OpUnreachable OpFunctionEnd )"; - auto result = SinglePassRunToBinary(text, true); - EXPECT_EQ(std::get<1>(result), opt::Pass::Status::SuccessWithoutChange); + auto result = SinglePassRunToBinary(text, true); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange); } TEST_F(DeadBranchElimTest, ExtraBackedgePartiallyDead) { @@ -1841,6 +1844,7 @@ TEST_F(DeadBranchElimTest, ExtraBackedgePartiallyDead) { ; CHECK: OpLabel ; CHECK: [[header:%\w+]] = OpLabel ; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK: [[merge]] = OpLabel ; CHECK: [[continue]] = OpLabel ; CHECK: OpBranch [[extra:%\w+]] ; CHECK: [[extra]] = OpLabel @@ -1851,7 +1855,6 @@ TEST_F(DeadBranchElimTest, ExtraBackedgePartiallyDead) { ; CHECK-NEXT: OpBranch [[backedge:%\w+]] ; CHECK-NEXT: [[backedge:%\w+]] = OpLabel ; CHECK-NEXT: OpBranch [[header]] -; CHECK-NEXT: [[merge]] = OpLabel OpCapability Kernel OpCapability Linkage OpMemoryModel Logical OpenCL @@ -1887,7 +1890,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, UnreachableContinuePhiInMerge) { @@ -1974,7 +1977,7 @@ TEST_F(DeadBranchElimTest, UnreachableContinuePhiInMerge) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(DeadBranchElimTest, NonStructuredIf) { @@ -2000,7 +2003,195 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ReorderBlocks) { + const std::string text = R"( +; CHECK: OpLabel +; CHECK: OpBranch [[label:%\w+]] +; CHECK: [[label:%\w+]] = OpLabel +; CHECK-NEXT: OpLogicalNot +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK: [[label]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpReturn +%2 = OpLabel +%not = OpLogicalNot %bool %true +OpBranch %3 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ReorderBlocksMultiple) { + // Checks are not important. The validation post optimization is the + // important part. + const std::string text = R"( +; CHECK: OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpReturn +%2 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ReorderBlocksMultiple2) { + // Checks are not important. The validation post optimization is the + // important part. + const std::string text = R"( +; CHECK: OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpBranch %5 +%5 = OpLabel +OpReturn +%2 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithEarlyExit1) { + // Checks that if a selection merge construct contains a conditional branch + // to the merge node, then the OpSelectionMerge instruction is positioned + // correctly. + const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranch [[taken_branch:%\w+]] +; CHECK-NEXT: [[taken_branch]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[merge:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[merge]] {{%\w+}} +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpSelectionMerge %outer_merge None +OpBranchConditional %true %bb1 %bb3 +%bb1 = OpLabel +OpBranchConditional %undef_bool %outer_merge %bb2 +%bb2 = OpLabel +OpBranch %outer_merge +%bb3 = OpLabel +OpBranch %outer_merge +%outer_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithEarlyExit2) { + // Checks that if a selection merge construct contains a conditional branch + // to the merge node, then the OpSelectionMerge instruction is positioned + // correctly. + const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK-NEXT: [[bb1]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[inner_merge:%\w+]] +; CHECK: [[inner_merge]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[outer_merge:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[outer_merge]:%\w+]] {{%\w+}} +; CHECK: [[outer_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpSelectionMerge %outer_merge None +OpBranchConditional %true %bb1 %bb5 +%bb1 = OpLabel +OpSelectionMerge %inner_merge None +OpBranchConditional %undef_bool %bb2 %bb3 +%bb2 = OpLabel +OpBranch %inner_merge +%bb3 = OpLabel +OpBranch %inner_merge +%inner_merge = OpLabel +OpBranchConditional %undef_bool %outer_merge %bb4 +%bb4 = OpLabel +OpBranch %outer_merge +%bb5 = OpLabel +OpBranch %outer_merge +%outer_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); } #endif @@ -2009,4 +2200,6 @@ OpFunctionEnd // More complex control flow // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dead_insert_elim_test.cpp b/3rdparty/spirv-tools/test/opt/dead_insert_elim_test.cpp index 7a434363a..8ae6894d8 100644 --- a/3rdparty/spirv-tools/test/opt/dead_insert_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/dead_insert_elim_test.cpp @@ -13,13 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using DeadInsertElimTest = PassTest<::testing::Test>; TEST_F(DeadInsertElimTest, InsertAfterInsertElim) { @@ -164,8 +166,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - before_predefs + before, after_predefs + after, true, true); + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); } TEST_F(DeadInsertElimTest, DeadInsertInChainWithPhi) { @@ -343,8 +345,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - before_predefs + before, after_predefs + after, true, true); + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); } TEST_F(DeadInsertElimTest, DeadInsertTwoPasses) { @@ -557,129 +559,13 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - before_predefs + before, after_predefs + after, true, true); -} - -TEST_F(DeadInsertElimTest, DeadInsertInCycleToDo) { - // Dead insert in chain with cycle. Demonstrates analysis can handle - // cycles in chains. - // - // TODO(greg-lunarg): Improve algorithm to remove dead insert into v.y. Will - // likely require similar logic to ADCE. - // - // Note: The SPIR-V assembly has had store/load elimination - // performed to allow the inserts and extracts to directly - // reference each other. - // - // #version 450 - // - // layout (location=0) in vec4 In0; - // layout (location=1) in float In1; - // layout (location=2) in float In2; - // layout (location=0) out vec4 OutColor; - // - // layout(std140, binding = 0 ) uniform _Globals_ - // { - // int g_n ; - // }; - // - // void main() - // { - // vec2 v = vec2(0.0, 1.0); - // for (int i = 0; i < g_n; i++) { - // v.x = v.x + 1; - // v.y = v.y * 0.9; // dead - // } - // OutColor = vec4(v.x); - // } - - const std::string assembly = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %OutColor %In0 %In1 %In2 -OpExecutionMode %main OriginUpperLeft -OpSource GLSL 450 -OpName %main "main" -OpName %_Globals_ "_Globals_" -OpMemberName %_Globals_ 0 "g_n" -OpName %_ "" -OpName %OutColor "OutColor" -OpName %In0 "In0" -OpName %In1 "In1" -OpName %In2 "In2" -OpMemberDecorate %_Globals_ 0 Offset 0 -OpDecorate %_Globals_ Block -OpDecorate %_ DescriptorSet 0 -OpDecorate %_ Binding 0 -OpDecorate %OutColor Location 0 -OpDecorate %In0 Location 0 -OpDecorate %In1 Location 1 -OpDecorate %In2 Location 2 -%void = OpTypeVoid -%10 = OpTypeFunction %void -%float = OpTypeFloat 32 -%v2float = OpTypeVector %float 2 -%_ptr_Function_v2float = OpTypePointer Function %v2float -%float_0 = OpConstant %float 0 -%float_1 = OpConstant %float 1 -%16 = OpConstantComposite %v2float %float_0 %float_1 -%int = OpTypeInt 32 1 -%_ptr_Function_int = OpTypePointer Function %int -%int_0 = OpConstant %int 0 -%_Globals_ = OpTypeStruct %int -%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ -%_ = OpVariable %_ptr_Uniform__Globals_ Uniform -%_ptr_Uniform_int = OpTypePointer Uniform %int -%bool = OpTypeBool -%float_0_75 = OpConstant %float 0.75 -%int_1 = OpConstant %int 1 -%v4float = OpTypeVector %float 4 -%_ptr_Output_v4float = OpTypePointer Output %v4float -%OutColor = OpVariable %_ptr_Output_v4float Output -%_ptr_Input_v4float = OpTypePointer Input %v4float -%In0 = OpVariable %_ptr_Input_v4float Input -%_ptr_Input_float = OpTypePointer Input %float -%In1 = OpVariable %_ptr_Input_float Input -%In2 = OpVariable %_ptr_Input_float Input -%main = OpFunction %void None %10 -%29 = OpLabel -OpBranch %30 -%30 = OpLabel -%31 = OpPhi %v2float %16 %29 %32 %33 -%34 = OpPhi %int %int_0 %29 %35 %33 -OpLoopMerge %36 %33 None -OpBranch %37 -%37 = OpLabel -%38 = OpAccessChain %_ptr_Uniform_int %_ %int_0 -%39 = OpLoad %int %38 -%40 = OpSLessThan %bool %34 %39 -OpBranchConditional %40 %41 %36 -%41 = OpLabel -%42 = OpCompositeExtract %float %31 0 -%43 = OpFAdd %float %42 %float_1 -%44 = OpCompositeInsert %v2float %43 %31 0 -%45 = OpCompositeExtract %float %44 1 -%46 = OpFMul %float %45 %float_0_75 -%32 = OpCompositeInsert %v2float %46 %44 1 -OpBranch %33 -%33 = OpLabel -%35 = OpIAdd %int %34 %int_1 -OpBranch %30 -%36 = OpLabel -%47 = OpCompositeExtract %float %31 0 -%48 = OpCompositeConstruct %v4float %47 %47 %47 %47 -OpStore %OutColor %48 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck(assembly, assembly, true, - true); + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); } // TODO(greg-lunarg): Add tests to verify handling of these cases: // -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dead_variable_elim_test.cpp b/3rdparty/spirv-tools/test/opt/dead_variable_elim_test.cpp index 676719cf4..fca13a8e2 100644 --- a/3rdparty/spirv-tools/test/opt/dead_variable_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/dead_variable_elim_test.cpp @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using DeadVariableElimTest = PassTest<::testing::Test>; // %dead is unused. Make sure we remove it along with its name. @@ -64,8 +66,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, - true); + SinglePassRunAndCheck(before, after, true, true); } // Since %dead is exported, make sure we keep it. It could be referenced @@ -94,8 +95,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, before, true, - true); + SinglePassRunAndCheck(before, before, true, true); } // Delete %dead because it is unreferenced. Then %initializer becomes @@ -144,8 +144,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, - true); + SinglePassRunAndCheck(before, after, true, true); } // Delete %dead because it is unreferenced. In this case, the initialized has @@ -198,8 +197,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, - true); + SinglePassRunAndCheck(before, after, true, true); } // Keep %live because it is used, and its initializer. @@ -229,8 +227,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, before, true, - true); + SinglePassRunAndCheck(before, before, true, true); } // This test that the decoration associated with a variable are removed when the @@ -293,7 +290,9 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, - true); + SinglePassRunAndCheck(before, after, true, true); } + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/decoration_manager_test.cpp b/3rdparty/spirv-tools/test/opt/decoration_manager_test.cpp index b503316fd..cf82e8e66 100644 --- a/3rdparty/spirv-tools/test/opt/decoration_manager_test.cpp +++ b/3rdparty/spirv-tools/test/opt/decoration_manager_test.cpp @@ -13,21 +13,23 @@ // limitations under the License. #include +#include +#include +#include -#include - +#include "gmock/gmock.h" #include "source/opt/build_module.h" #include "source/opt/decoration_manager.h" #include "source/opt/ir_context.h" #include "source/spirv_constant.h" -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { +namespace opt { +namespace analysis { namespace { using spvtest::MakeVector; -using spvtools::ir::Instruction; -using spvtools::ir::IRContext; -using spvtools::opt::analysis::DecorationManager; class DecorationManagerTest : public ::testing::Test { public: @@ -61,10 +63,10 @@ class DecorationManagerTest : public ::testing::Test { tools_.SetMessageConsumer(consumer_); } - virtual void TearDown() override { error_message_.clear(); } + void TearDown() override { error_message_.clear(); } DecorationManager* GetDecorationManager(const std::string& text) { - context_ = spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text); + context_ = BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text); if (context_.get()) return context_->get_decoration_mgr(); else @@ -104,8 +106,8 @@ class DecorationManagerTest : public ::testing::Test { spvtools::MessageConsumer GetConsumer() { return consumer_; } private: - spvtools::SpirvTools - tools_; // An instance for calling SPIRV-Tools functionalities. + // An instance for calling SPIRV-Tools functionalities. + spvtools::SpirvTools tools_; std::unique_ptr context_; spvtools::MessageConsumer consumer_; uint32_t disassemble_options_; @@ -114,7 +116,7 @@ class DecorationManagerTest : public ::testing::Test { TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffOpcodesDecorateDecorateId) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // This parameter can be interprated both as { SpvDecorationConstant } // and also as a list of IDs: { 22 } const std::vector param{SpvDecorationConstant}; @@ -133,7 +135,7 @@ TEST_F(DecorationManagerTest, TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffOpcodesDecorateDecorateString) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // This parameter can be interprated both as { SpvDecorationConstant } // and also as a null-terminated string with a single character with value 22. const std::vector param{SpvDecorationConstant}; @@ -151,7 +153,7 @@ TEST_F(DecorationManagerTest, } TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateParam) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpDecorate %1 Constant Instruction inst1(&ir_context, SpvOpDecorate, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -166,7 +168,7 @@ TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateParam) { } TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateIdParam) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpDecorate %1 Constant Instruction inst1( &ir_context, SpvOpDecorateId, 0u, 0u, @@ -181,7 +183,7 @@ TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateIdParam) { } TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateStringParam) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpDecorate %1 Constant Instruction inst1(&ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -196,7 +198,7 @@ TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateStringParam) { } TEST_F(DecorationManagerTest, ComparingSameDecorationsOnDiffTargetAllowed) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpDecorate %1 Constant Instruction inst1(&ir_context, SpvOpDecorate, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -211,7 +213,7 @@ TEST_F(DecorationManagerTest, ComparingSameDecorationsOnDiffTargetAllowed) { } TEST_F(DecorationManagerTest, ComparingSameDecorationIdsOnDiffTargetAllowed) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); Instruction inst1( &ir_context, SpvOpDecorateId, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_DECORATION, {44}}}); @@ -225,7 +227,7 @@ TEST_F(DecorationManagerTest, ComparingSameDecorationIdsOnDiffTargetAllowed) { TEST_F(DecorationManagerTest, ComparingSameDecorationStringsOnDiffTargetAllowed) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); Instruction inst1(&ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_LITERAL_STRING, MakeVector("hello")}}); @@ -238,7 +240,7 @@ TEST_F(DecorationManagerTest, } TEST_F(DecorationManagerTest, ComparingSameDecorationsOnDiffTargetDisallowed) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpDecorate %1 Constant Instruction inst1(&ir_context, SpvOpDecorate, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -253,7 +255,7 @@ TEST_F(DecorationManagerTest, ComparingSameDecorationsOnDiffTargetDisallowed) { } TEST_F(DecorationManagerTest, ComparingMemberDecorationsOnSameTypeDiffMember) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpMemberDecorate %1 0 Constant Instruction inst1(&ir_context, SpvOpMemberDecorate, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -271,7 +273,7 @@ TEST_F(DecorationManagerTest, ComparingMemberDecorationsOnSameTypeDiffMember) { TEST_F(DecorationManagerTest, ComparingSameMemberDecorationsOnDiffTargetAllowed) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpMemberDecorate %1 0 Constant Instruction inst1(&ir_context, SpvOpMemberDecorate, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -289,7 +291,7 @@ TEST_F(DecorationManagerTest, TEST_F(DecorationManagerTest, ComparingSameMemberDecorationsOnDiffTargetDisallowed) { - spvtools::ir::IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); // OpMemberDecorate %1 0 Constant Instruction inst1(&ir_context, SpvOpMemberDecorate, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {1u}}, @@ -481,11 +483,10 @@ OpGroupDecorate %3 %1 )"; DecorationManager* decoManager = GetDecorationManager(spirv); EXPECT_THAT(GetErrorMessage(), ""); - decoManager->RemoveDecorationsFrom( - 1u, [](const spvtools::ir::Instruction& inst) { - return inst.opcode() == SpvOpDecorate && - inst.GetSingleWordInOperand(0u) == 3u; - }); + decoManager->RemoveDecorationsFrom(1u, [](const Instruction& inst) { + return inst.opcode() == SpvOpDecorate && + inst.GetSingleWordInOperand(0u) == 3u; + }); auto decorations = decoManager->GetDecorationsFor(1u, false); EXPECT_THAT(GetErrorMessage(), ""); @@ -533,12 +534,11 @@ OpGroupDecorate %3 %1 )"; DecorationManager* decoManager = GetDecorationManager(spirv); EXPECT_THAT(GetErrorMessage(), ""); - decoManager->RemoveDecorationsFrom( - 1u, [](const spvtools::ir::Instruction& inst) { - return inst.opcode() == SpvOpDecorate && - inst.GetSingleWordInOperand(0u) == 3u && - inst.GetSingleWordInOperand(1u) == SpvDecorationBuiltIn; - }); + decoManager->RemoveDecorationsFrom(1u, [](const Instruction& inst) { + return inst.opcode() == SpvOpDecorate && + inst.GetSingleWordInOperand(0u) == 3u && + inst.GetSingleWordInOperand(1u) == SpvDecorationBuiltIn; + }); auto decorations = decoManager->GetDecorationsFor(1u, false); EXPECT_THAT(GetErrorMessage(), ""); @@ -734,6 +734,123 @@ OpDecorate %5 Aliased EXPECT_THAT(ModuleToText(), expected_binary); } +TEST_F(DecorationManagerTest, CloneSomeDecorations) { + const std::string spirv = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Function %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +%1 = OpVariable %3 Function +%8 = OpUndef %2 +OpReturn +OpFunctionEnd +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_EQ(GetErrorMessage(), ""); + + // Check cloning OpDecorate including group decorations. + auto decorations = decoManager->GetDecorationsFor(8u, false); + EXPECT_EQ(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + + decoManager->CloneDecorations(1u, 8u, {SpvDecorationRelaxedPrecision}); + decorations = decoManager->GetDecorationsFor(8u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = + R"(OpDecorate %8 RelaxedPrecision +)"; + EXPECT_EQ(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +OpDecorate %8 RelaxedPrecision +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Function %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +%1 = OpVariable %3 Function +%8 = OpUndef %2 +OpReturn +OpFunctionEnd +)"; + EXPECT_EQ(ModuleToText(), expected_binary); +} + +// Test cloning decoration for an id that is decorated via a group decoration. +TEST_F(DecorationManagerTest, CloneSomeGroupDecorations) { + const std::string spirv = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +%1 = OpDecorationGroup +OpGroupDecorate %1 %2 +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Function %3 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpFunction %5 None %6 +%8 = OpLabel +%2 = OpVariable %4 Function +%9 = OpUndef %3 +OpReturn +OpFunctionEnd +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_EQ(GetErrorMessage(), ""); + + // Check cloning OpDecorate including group decorations. + auto decorations = decoManager->GetDecorationsFor(9u, false); + EXPECT_EQ(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + + decoManager->CloneDecorations(2u, 9u, {SpvDecorationRelaxedPrecision}); + decorations = decoManager->GetDecorationsFor(9u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = + R"(OpDecorate %9 RelaxedPrecision +)"; + EXPECT_EQ(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +%1 = OpDecorationGroup +OpGroupDecorate %1 %2 +OpDecorate %9 RelaxedPrecision +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Function %3 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpFunction %5 None %6 +%8 = OpLabel +%2 = OpVariable %4 Function +%9 = OpUndef %3 +OpReturn +OpFunctionEnd +)"; + EXPECT_EQ(ModuleToText(), expected_binary); +} + TEST_F(DecorationManagerTest, HaveTheSameDecorationsWithoutGroupsTrue) { const std::string spirv = R"( OpCapability Shader @@ -1160,3 +1277,6 @@ OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" } } // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/def_use_test.cpp b/3rdparty/spirv-tools/test/opt/def_use_test.cpp index cdc829be1..3b856ce7f 100644 --- a/3rdparty/spirv-tools/test/opt/def_use_test.cpp +++ b/3rdparty/spirv-tools/test/opt/def_use_test.cpp @@ -13,33 +13,36 @@ // limitations under the License. #include +#include +#include #include +#include +#include -#include -#include - -#include "opt/build_module.h" -#include "opt/def_use_manager.h" -#include "opt/ir_context.h" -#include "opt/module.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" #include "spirv-tools/libspirv.hpp" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { +namespace analysis { namespace { using ::testing::Contains; using ::testing::UnorderedElementsAre; using ::testing::UnorderedElementsAreArray; -using namespace spvtools; -using spvtools::opt::analysis::DefUseManager; - // Returns the number of uses of |id|. -uint32_t NumUses(const std::unique_ptr& context, uint32_t id) { +uint32_t NumUses(const std::unique_ptr& context, uint32_t id) { uint32_t count = 0; context->get_def_use_mgr()->ForEachUse( - id, [&count](ir::Instruction*, uint32_t) { ++count; }); + id, [&count](Instruction*, uint32_t) { ++count; }); return count; } @@ -47,18 +50,18 @@ uint32_t NumUses(const std::unique_ptr& context, uint32_t id) { // // If |id| is used multiple times in a single instruction, that instruction's // opcode will appear a corresponding number of times. -std::vector GetUseOpcodes(const std::unique_ptr& context, +std::vector GetUseOpcodes(const std::unique_ptr& context, uint32_t id) { std::vector opcodes; context->get_def_use_mgr()->ForEachUse( - id, [&opcodes](ir::Instruction* user, uint32_t) { + id, [&opcodes](Instruction* user, uint32_t) { opcodes.push_back(user->opcode()); }); return opcodes; } // Disassembles the given |inst| and returns the disassembly. -std::string DisassembleInst(ir::Instruction* inst) { +std::string DisassembleInst(Instruction* inst) { SpirvTools tools(SPV_ENV_UNIVERSAL_1_1); std::vector binary; @@ -103,7 +106,7 @@ void CheckDef(const InstDefUse& expected_defs_uses, } } -using UserMap = std::unordered_map>; +using UserMap = std::unordered_map>; // Creates a mapping of all definitions to their users (except OpConstant). // @@ -112,7 +115,7 @@ UserMap BuildAllUsers(const DefUseManager* mgr, uint32_t idBound) { UserMap userMap; for (uint32_t id = 0; id != idBound; ++id) { if (mgr->GetDef(id)) { - mgr->ForEachUser(id, [id, &userMap](ir::Instruction* user) { + mgr->ForEachUser(id, [id, &userMap](Instruction* user) { if (user->opcode() != SpvOpConstant) { userMap[id].push_back(user); } @@ -190,13 +193,13 @@ TEST_P(ParseDefUseTest, Case) { // Build module. const std::vector text = {tc.text}; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Analyze def and use. - opt::analysis::DefUseManager manager(context->module()); + DefUseManager manager(context->module()); CheckDef(tc.du, manager.id_to_defs()); CheckUse(tc.du, &manager, context->module()->IdBound()); @@ -483,8 +486,8 @@ INSTANTIATE_TEST_CASE_P( {6, { // Can't check constants properly - //"%8 = OpConstant %6 0", - //"%18 = OpConstant %6 1", + // "%8 = OpConstant %6 0", + // "%18 = OpConstant %6 1", "%7 = OpPhi %6 %8 %4 %9 %5", "%9 = OpIAdd %6 %7 %8", } @@ -504,7 +507,7 @@ INSTANTIATE_TEST_CASE_P( {9, {"%7 = OpPhi %6 %8 %4 %9 %5"}}, {10, { - //"%12 = OpConstant %10 1.0", + // "%12 = OpConstant %10 1.0", "%11 = OpPhi %10 %12 %4 %13 %5", "%13 = OpFAdd %10 %11 %12", } @@ -587,7 +590,7 @@ struct ReplaceUseCase { using ReplaceUseTest = ::testing::TestWithParam; // Disassembles the given |module| and returns the disassembly. -std::string DisassembleModule(ir::Module* module) { +std::string DisassembleModule(Module* module) { SpirvTools tools(SPV_ENV_UNIVERSAL_1_1); std::vector binary; @@ -606,13 +609,13 @@ TEST_P(ReplaceUseTest, Case) { // Build module. const std::vector text = {tc.before}; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Force a re-build of def-use manager. - context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse); (void)context->get_def_use_mgr(); // Do the substitution. @@ -831,8 +834,8 @@ INSTANTIATE_TEST_CASE_P( {6, { // Can't properly check constants - //"%8 = OpConstant %6 0", - //"%18 = OpConstant %6 1", + // "%8 = OpConstant %6 0", + // "%18 = OpConstant %6 1", "%7 = OpPhi %6 %8 %4 %13 %5", "%9 = OpIAdd %6 %7 %8" } @@ -853,7 +856,7 @@ INSTANTIATE_TEST_CASE_P( {10, { "%11 = OpPhi %10 %12 %4 %13 %5", - //"%12 = OpConstant %10 1", + // "%12 = OpConstant %10 1", "%13 = OpFAdd %10 %9 %12" } }, @@ -961,13 +964,13 @@ TEST_P(KillDefTest, Case) { // Build module. const std::vector text = {tc.before}; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Analyze def and use. - opt::analysis::DefUseManager manager(context->module()); + DefUseManager manager(context->module()); // Do the substitution. for (const auto id : tc.ids_to_kill) context->KillDef(id); @@ -1085,14 +1088,14 @@ INSTANTIATE_TEST_CASE_P( {4, { "%7 = OpPhi %6 %8 %4 %9 %5", - //"%11 = OpPhi %10 %12 %4 %13 %5", + // "%11 = OpPhi %10 %12 %4 %13 %5", } }, {5, { "OpBranch %5", "%7 = OpPhi %6 %8 %4 %9 %5", - //"%11 = OpPhi %10 %12 %4 %13 %5", + // "%11 = OpPhi %10 %12 %4 %13 %5", "OpLoopMerge %19 %5 None", "OpBranchConditional %17 %5 %19", } @@ -1100,35 +1103,35 @@ INSTANTIATE_TEST_CASE_P( {6, { // Can't properly check constants - //"%8 = OpConstant %6 0", - //"%18 = OpConstant %6 1", + // "%8 = OpConstant %6 0", + // "%18 = OpConstant %6 1", "%7 = OpPhi %6 %8 %4 %9 %5", - //"%9 = OpIAdd %6 %7 %8" + // "%9 = OpIAdd %6 %7 %8" } }, {7, {"%17 = OpSLessThan %16 %7 %18"}}, {8, { "%7 = OpPhi %6 %8 %4 %9 %5", - //"%9 = OpIAdd %6 %7 %8", + // "%9 = OpIAdd %6 %7 %8", } }, // {9, {"%7 = OpPhi %6 %8 %4 %13 %5"}}, {10, { - //"%11 = OpPhi %10 %12 %4 %13 %5", - //"%12 = OpConstant %10 1", + // "%11 = OpPhi %10 %12 %4 %13 %5", + // "%12 = OpConstant %10 1", "%13 = OpFAdd %10 %11 %12" } }, // {11, {"%13 = OpFAdd %10 %11 %12"}}, {12, { - //"%11 = OpPhi %10 %12 %4 %13 %5", + // "%11 = OpPhi %10 %12 %4 %13 %5", "%13 = OpFAdd %10 %11 %12" } }, - //{13, {"%11 = OpPhi %10 %12 %4 %13 %5"}}, + // {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}}, {16, {"%17 = OpSLessThan %16 %7 %18"}}, {17, {"OpBranchConditional %17 %5 %19"}}, {18, {"%17 = OpSLessThan %16 %7 %18"}}, @@ -1238,13 +1241,13 @@ TEST(DefUseTest, OpSwitch) { " OpReturnValue %6 " " OpFunctionEnd"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Force a re-build of def-use manager. - context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse); (void)context->get_def_use_mgr(); // Do a bunch replacements. @@ -1327,12 +1330,12 @@ TEST_P(AnalyzeInstDefUseTest, Case) { auto tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text); ASSERT_NE(nullptr, context); // Analyze the instructions. - opt::analysis::DefUseManager manager(context->module()); + DefUseManager manager(context->module()); CheckDef(tc.expected_define_use, manager.id_to_defs()); CheckUse(tc.expected_define_use, &manager, context->module()->IdBound()); @@ -1370,16 +1373,15 @@ INSTANTIATE_TEST_CASE_P( using AnalyzeInstDefUse = ::testing::Test; TEST(AnalyzeInstDefUse, UseWithNoResultId) { - ir::IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); // Analyze the instructions. - opt::analysis::DefUseManager manager(context.module()); + DefUseManager manager(context.module()); - ir::Instruction label(&context, SpvOpLabel, 0, 2, {}); + Instruction label(&context, SpvOpLabel, 0, 2, {}); manager.AnalyzeInstDefUse(&label); - ir::Instruction branch(&context, SpvOpBranch, 0, 0, - {{SPV_OPERAND_TYPE_ID, {2}}}); + Instruction branch(&context, SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {2}}}); manager.AnalyzeInstDefUse(&branch); context.module()->SetIdBound(3); @@ -1400,14 +1402,14 @@ TEST(AnalyzeInstDefUse, AddNewInstruction) { const std::string input = "%1 = OpTypeBool"; // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input); ASSERT_NE(nullptr, context); // Analyze the instructions. - opt::analysis::DefUseManager manager(context->module()); + DefUseManager manager(context->module()); - ir::Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {}); + Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {}); manager.AnalyzeInstDefUse(&newInst); InstDefUse expected = { @@ -1439,17 +1441,17 @@ TEST_P(KillInstTest, Case) { auto tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Force a re-build of the def-use manager. - context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse); (void)context->get_def_use_mgr(); // KillInst - context->module()->ForEachInst([&tc, &context](ir::Instruction* inst) { + context->module()->ForEachInst([&tc, &context](Instruction* inst) { if (tc.indices_for_inst_to_kill.count(inst->result_id())) { context->KillInst(inst); } @@ -1566,12 +1568,12 @@ TEST_P(GetAnnotationsTest, Case) { const GetAnnotationsTestCase& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code); ASSERT_NE(nullptr, context); // Get annotations - opt::analysis::DefUseManager manager(context->module()); + DefUseManager manager(context->module()); auto insts = manager.GetAnnotations(tc.id); // Check @@ -1693,21 +1695,25 @@ TEST_F(UpdateUsesTest, KeepOldUses) { // clang-format on }; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* def = def_use_mgr->GetDef(9); - ir::Instruction* use = def_use_mgr->GetDef(10); + Instruction* def = def_use_mgr->GetDef(9); + Instruction* use = def_use_mgr->GetDef(10); def->SetOpcode(SpvOpCopyObject); def->SetInOperands({{SPV_OPERAND_TYPE_ID, {25}}}); context->UpdateDefUse(def); auto users = def_use_mgr->id_to_users(); - opt::analysis::UserEntry entry = {def, use}; + UserEntry entry = {def, use}; EXPECT_THAT(users, Contains(entry)); } // clang-format on -} // anonymous namespace + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/CMakeLists.txt b/3rdparty/spirv-tools/test/opt/dominator_tree/CMakeLists.txt index 31a5f322e..f95a56da8 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/CMakeLists.txt @@ -13,67 +13,18 @@ # limitations under the License. -add_spvtools_unittest(TARGET dominator_analysis_simple - SRCS ../function_utils.h - simple.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_post - SRCS ../function_utils.h - post.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_nested_ifs - SRCS ../function_utils.h - nested_ifs.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_nested_ifs_post - SRCS ../function_utils.h - nested_ifs_post.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_nested_loops - SRCS ../function_utils.h - nested_loops.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_nested_loops_with_unreachables - SRCS ../function_utils.h - nested_loops_with_unreachables.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_switch_case_fallthrough - SRCS ../function_utils.h - switch_case_fallthrough.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_unreachable_for - SRCS ../function_utils.h - unreachable_for.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_analysis_unreachable_for_post - SRCS ../function_utils.h - unreachable_for_post.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_generated - SRCS ../function_utils.h - generated.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET dominator_common_dominators - SRCS common_dominators.cpp - LIBS SPIRV-Tools-opt +add_spvtools_unittest(TARGET dominator_analysis + SRCS ../function_utils.h + common_dominators.cpp + generated.cpp + nested_ifs.cpp + nested_ifs_post.cpp + nested_loops.cpp + nested_loops_with_unreachables.cpp + post.cpp + simple.cpp + switch_case_fallthrough.cpp + unreachable_for.cpp + unreachable_for_post.cpp + LIBS SPIRV-Tools-opt ) diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/common_dominators.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/common_dominators.cpp index 612dc4543..dfa03e986 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/common_dominators.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/common_dominators.cpp @@ -12,15 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include -#include "opt/build_module.h" -#include "opt/ir_context.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using CommonDominatorsTest = ::testing::Test; const std::string text = R"( @@ -59,19 +62,18 @@ OpReturn OpFunctionEnd )"; -ir::BasicBlock* GetBlock(uint32_t id, std::unique_ptr& context) { +BasicBlock* GetBlock(uint32_t id, std::unique_ptr& context) { return context->get_instr_block(context->get_def_use_mgr()->GetDef(id)); } TEST(CommonDominatorsTest, SameBlock) { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, context); - ir::CFG cfg(context->module()); - opt::DominatorAnalysis* analysis = - context->GetDominatorAnalysis(&*context->module()->begin(), cfg); + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); for (auto& block : *context->module()->begin()) { EXPECT_EQ(&block, analysis->CommonDominator(&block, &block)); @@ -79,14 +81,13 @@ TEST(CommonDominatorsTest, SameBlock) { } TEST(CommonDominatorsTest, ParentAndChild) { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, context); - ir::CFG cfg(context->module()); - opt::DominatorAnalysis* analysis = - context->GetDominatorAnalysis(&*context->module()->begin(), cfg); + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); EXPECT_EQ( GetBlock(1u, context), @@ -100,14 +101,13 @@ TEST(CommonDominatorsTest, ParentAndChild) { } TEST(CommonDominatorsTest, BranchSplit) { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, context); - ir::CFG cfg(context->module()); - opt::DominatorAnalysis* analysis = - context->GetDominatorAnalysis(&*context->module()->begin(), cfg); + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); EXPECT_EQ( GetBlock(3u, context), @@ -118,14 +118,13 @@ TEST(CommonDominatorsTest, BranchSplit) { } TEST(CommonDominatorsTest, LoopContinueAndMerge) { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, context); - ir::CFG cfg(context->module()); - opt::DominatorAnalysis* analysis = - context->GetDominatorAnalysis(&*context->module()->begin(), cfg); + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); EXPECT_EQ( GetBlock(5u, context), @@ -133,14 +132,13 @@ TEST(CommonDominatorsTest, LoopContinueAndMerge) { } TEST(CommonDominatorsTest, NoCommonDominator) { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, context); - ir::CFG cfg(context->module()); - opt::DominatorAnalysis* analysis = - context->GetDominatorAnalysis(&*context->module()->begin(), cfg); + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); EXPECT_EQ(nullptr, analysis->CommonDominator(GetBlock(10u, context), GetBlock(11u, context))); @@ -148,4 +146,6 @@ TEST(CommonDominatorsTest, NoCommonDominator) { GetBlock(6u, context))); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/generated.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/generated.cpp index 80fd86208..43b723e93 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/generated.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/generated.cpp @@ -12,27 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include #include #include -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/iterator.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/iterator.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; // Check that x dominates y, and @@ -42,8 +41,8 @@ using PassClassTest = PassTest<::testing::Test>; // y does not strictly dominate x // if x == x then // x does not strictly dominate itself -void check_dominance(const opt::DominatorAnalysisBase& dom_tree, - const ir::Function* fn, uint32_t x, uint32_t y) { +void check_dominance(const DominatorAnalysisBase& dom_tree, const Function* fn, + uint32_t x, uint32_t y) { SCOPED_TRACE("Check dominance properties for Basic Block " + std::to_string(x) + " and " + std::to_string(y)); EXPECT_TRUE(dom_tree.Dominates(spvtest::GetBasicBlock(fn, x), @@ -59,8 +58,8 @@ void check_dominance(const opt::DominatorAnalysisBase& dom_tree, } // Check that x does not dominates y and vise versa -void check_no_dominance(const opt::DominatorAnalysisBase& dom_tree, - const ir::Function* fn, uint32_t x, uint32_t y) { +void check_no_dominance(const DominatorAnalysisBase& dom_tree, + const Function* fn, uint32_t x, uint32_t y) { SCOPED_TRACE("Check no domination for Basic Block " + std::to_string(x) + " and " + std::to_string(y)); EXPECT_FALSE(dom_tree.Dominates(spvtest::GetBasicBlock(fn, x), @@ -108,25 +107,25 @@ TEST_F(PassClassTest, DominatorSimpleCFG) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* fn = spvtest::GetFunction(module, 1); - const ir::BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + const Function* fn = spvtest::GetFunction(module, 1); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); EXPECT_EQ(entry, fn->entry().get()) << "The entry node is not the expected one"; // Test normal dominator tree { - opt::DominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE( dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); @@ -155,8 +154,8 @@ TEST_F(PassClassTest, DominatorSimpleCFG) { // check with some invalid inputs EXPECT_FALSE(dom_tree.Dominates(nullptr, entry)); EXPECT_FALSE(dom_tree.Dominates(entry, nullptr)); - EXPECT_FALSE(dom_tree.Dominates(static_cast(nullptr), - static_cast(nullptr))); + EXPECT_FALSE(dom_tree.Dominates(static_cast(nullptr), + static_cast(nullptr))); EXPECT_FALSE(dom_tree.Dominates(10, 1)); EXPECT_FALSE(dom_tree.Dominates(1, 10)); EXPECT_FALSE(dom_tree.Dominates(1, 1)); @@ -186,12 +185,12 @@ TEST_F(PassClassTest, DominatorSimpleCFG) { // Test post dominator tree { - opt::PostDominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 15)); @@ -217,8 +216,8 @@ TEST_F(PassClassTest, DominatorSimpleCFG) { // check with some invalid inputs EXPECT_FALSE(dom_tree.Dominates(nullptr, entry)); EXPECT_FALSE(dom_tree.Dominates(entry, nullptr)); - EXPECT_FALSE(dom_tree.Dominates(static_cast(nullptr), - static_cast(nullptr))); + EXPECT_FALSE(dom_tree.Dominates(static_cast(nullptr), + static_cast(nullptr))); EXPECT_FALSE(dom_tree.Dominates(10, 1)); EXPECT_FALSE(dom_tree.Dominates(1, 10)); EXPECT_FALSE(dom_tree.Dominates(1, 1)); @@ -274,26 +273,26 @@ TEST_F(PassClassTest, DominatorIrreducibleCFG) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* fn = spvtest::GetFunction(module, 1); + const Function* fn = spvtest::GetFunction(module, 1); - const ir::BasicBlock* entry = spvtest::GetBasicBlock(fn, 8); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 8); EXPECT_EQ(entry, fn->entry().get()) << "The entry node is not the expected one"; // Check normal dominator tree { - opt::DominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE( dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); @@ -330,12 +329,12 @@ TEST_F(PassClassTest, DominatorIrreducibleCFG) { // Check post dominator tree { - opt::PostDominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 12)); @@ -395,26 +394,26 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* fn = spvtest::GetFunction(module, 1); + const Function* fn = spvtest::GetFunction(module, 1); - const ir::BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); EXPECT_EQ(entry, fn->entry().get()) << "The entry node is not the expected one"; // Check normal dominator tree { - opt::DominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE( dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); @@ -437,8 +436,8 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { std::array node_order = {{10, 11, 12}}; { // Test dominator tree iteration order. - opt::DominatorTree::iterator node_it = dom_tree.GetDomTree().begin(); - opt::DominatorTree::iterator node_end = dom_tree.GetDomTree().end(); + DominatorTree::iterator node_it = dom_tree.GetDomTree().begin(); + DominatorTree::iterator node_end = dom_tree.GetDomTree().end(); for (uint32_t id : node_order) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); @@ -448,10 +447,8 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } { // Same as above, but with const iterators. - opt::DominatorTree::const_iterator node_it = - dom_tree.GetDomTree().cbegin(); - opt::DominatorTree::const_iterator node_end = - dom_tree.GetDomTree().cend(); + DominatorTree::const_iterator node_it = dom_tree.GetDomTree().cbegin(); + DominatorTree::const_iterator node_end = dom_tree.GetDomTree().cend(); for (uint32_t id : node_order) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); @@ -461,12 +458,9 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } { // Test dominator tree iteration order. - opt::DominatorTree::post_iterator node_it = - dom_tree.GetDomTree().post_begin(); - opt::DominatorTree::post_iterator node_end = - dom_tree.GetDomTree().post_end(); - for (uint32_t id : - ir::make_range(node_order.rbegin(), node_order.rend())) { + DominatorTree::post_iterator node_it = dom_tree.GetDomTree().post_begin(); + DominatorTree::post_iterator node_end = dom_tree.GetDomTree().post_end(); + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); node_it++; @@ -475,12 +469,11 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } { // Same as above, but with const iterators. - opt::DominatorTree::const_post_iterator node_it = + DominatorTree::const_post_iterator node_it = dom_tree.GetDomTree().post_cbegin(); - opt::DominatorTree::const_post_iterator node_end = + DominatorTree::const_post_iterator node_end = dom_tree.GetDomTree().post_cend(); - for (uint32_t id : - ir::make_range(node_order.rbegin(), node_order.rend())) { + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); node_it++; @@ -491,12 +484,12 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { // Check post dominator tree { - opt::PostDominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 12)); @@ -521,8 +514,8 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { std::array node_order = {{12, 11, 10}}; { // Test dominator tree iteration order. - opt::DominatorTree::iterator node_it = tree.begin(); - opt::DominatorTree::iterator node_end = tree.end(); + DominatorTree::iterator node_it = tree.begin(); + DominatorTree::iterator node_end = tree.end(); for (uint32_t id : node_order) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); @@ -532,8 +525,8 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } { // Same as above, but with const iterators. - opt::DominatorTree::const_iterator node_it = tree.cbegin(); - opt::DominatorTree::const_iterator node_end = tree.cend(); + DominatorTree::const_iterator node_it = tree.cbegin(); + DominatorTree::const_iterator node_end = tree.cend(); for (uint32_t id : node_order) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); @@ -543,12 +536,9 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } { // Test dominator tree iteration order. - opt::DominatorTree::post_iterator node_it = - dom_tree.GetDomTree().post_begin(); - opt::DominatorTree::post_iterator node_end = - dom_tree.GetDomTree().post_end(); - for (uint32_t id : - ir::make_range(node_order.rbegin(), node_order.rend())) { + DominatorTree::post_iterator node_it = dom_tree.GetDomTree().post_begin(); + DominatorTree::post_iterator node_end = dom_tree.GetDomTree().post_end(); + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); node_it++; @@ -557,12 +547,11 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } { // Same as above, but with const iterators. - opt::DominatorTree::const_post_iterator node_it = + DominatorTree::const_post_iterator node_it = dom_tree.GetDomTree().post_cbegin(); - opt::DominatorTree::const_post_iterator node_end = + DominatorTree::const_post_iterator node_end = dom_tree.GetDomTree().post_cend(); - for (uint32_t id : - ir::make_range(node_order.rbegin(), node_order.rend())) { + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { EXPECT_NE(node_it, node_end); EXPECT_EQ(node_it->id(), id); node_it++; @@ -602,26 +591,26 @@ TEST_F(PassClassTest, DominatorUnreachableInLoop) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* fn = spvtest::GetFunction(module, 1); + const Function* fn = spvtest::GetFunction(module, 1); - const ir::BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); EXPECT_EQ(entry, fn->entry().get()) << "The entry node is not the expected one"; // Check normal dominator tree { - opt::DominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE( dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); @@ -667,9 +656,9 @@ TEST_F(PassClassTest, DominatorUnreachableInLoop) { // Check post dominator tree. { - opt::PostDominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // (strict) dominance checks. for (uint32_t id : {10, 11, 12, 13, 14, 15}) @@ -733,25 +722,25 @@ TEST_F(PassClassTest, DominatorInfinitLoop) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* fn = spvtest::GetFunction(module, 1); + const Function* fn = spvtest::GetFunction(module, 1); - const ir::BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); EXPECT_EQ(entry, fn->entry().get()) << "The entry node is not the expected one"; // Check normal dominator tree { - opt::DominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE( dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); @@ -781,12 +770,12 @@ TEST_F(PassClassTest, DominatorInfinitLoop) { // Check post dominator tree { - opt::PostDominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 12)); @@ -837,25 +826,26 @@ TEST_F(PassClassTest, DominatorUnreachableFromEntry) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* fn = spvtest::GetFunction(module, 1); + const Function* fn = spvtest::GetFunction(module, 1); - const ir::BasicBlock* entry = spvtest::GetBasicBlock(fn, 8); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 8); EXPECT_EQ(entry, fn->entry().get()) << "The entry node is not the expected one"; // Check dominator tree { - opt::DominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE( dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); @@ -879,12 +869,12 @@ TEST_F(PassClassTest, DominatorUnreachableFromEntry) { // Check post dominator tree { - opt::PostDominatorAnalysis dom_tree; - ir::CFG cfg(module); - dom_tree.InitializeTree(fn, cfg); + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); // Inspect the actual tree - opt::DominatorTree& tree = dom_tree.GetDomTree(); + DominatorTree& tree = dom_tree.GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 9)); @@ -906,3 +896,5 @@ TEST_F(PassClassTest, DominatorUnreachableFromEntry) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs.cpp index fbb2d12cc..0552b7580 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs.cpp @@ -12,24 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include #include -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -110,17 +110,16 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); + const Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::DominatorAnalysis* analysis = context->GetDominatorAnalysis(f, cfg); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(5, 8)); EXPECT_TRUE(analysis->Dominates(5, 9)); @@ -150,3 +149,5 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs_post.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs_post.cpp index a112af99c..ad759df86 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs_post.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_ifs_post.cpp @@ -12,24 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -110,18 +109,16 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); + const Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::PostDominatorAnalysis* analysis = - context->GetPostDominatorAnalysis(f, cfg); + PostDominatorAnalysis* analysis = context->GetPostDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(5, 5)); EXPECT_TRUE(analysis->Dominates(8, 8)); @@ -155,3 +152,5 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops.cpp index 4e5c072c8..7d03937b1 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops.cpp @@ -12,24 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -351,16 +350,15 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::DominatorAnalysis* analysis = context->GetDominatorAnalysis(f, cfg); + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(5, 10)); EXPECT_TRUE(analysis->Dominates(5, 46)); @@ -431,3 +429,5 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops_with_unreachables.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops_with_unreachables.cpp index ab1b6a796..e87e8ddab 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops_with_unreachables.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/nested_loops_with_unreachables.cpp @@ -12,21 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; using PassClassTest = PassTest<::testing::Test>; @@ -279,16 +280,15 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::DominatorAnalysis* analysis = context->GetDominatorAnalysis(f, cfg); + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(5, 10)); EXPECT_TRUE(analysis->Dominates(5, 14)); @@ -844,3 +844,5 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/post.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/post.cpp index 2c786aee7..bb10fdef1 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/post.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/post.cpp @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -139,17 +139,16 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::PostDominatorAnalysis* analysis = - context->GetPostDominatorAnalysis(f, cfg); + const Function* f = spvtest::GetFunction(module, 4); + CFG cfg(module); + PostDominatorAnalysis* analysis = context->GetPostDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(19, 18)); EXPECT_TRUE(analysis->Dominates(19, 5)); @@ -204,3 +203,5 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/simple.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/simple.cpp index 4cba7dc8a..d11854d55 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/simple.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/simple.cpp @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -139,18 +139,18 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); + const Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::DominatorAnalysis* analysis = context->GetDominatorAnalysis(f, cfg); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + const CFG& cfg = *context->cfg(); - opt::DominatorTree& tree = analysis->GetDomTree(); + DominatorTree& tree = analysis->GetDomTree(); EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); EXPECT_TRUE(analysis->Dominates(5, 18)); @@ -173,3 +173,5 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/switch_case_fallthrough.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/switch_case_fallthrough.cpp index 2b2e27ba4..d9dd7d161 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/switch_case_fallthrough.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/switch_case_fallthrough.cpp @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -130,16 +130,15 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::DominatorAnalysis* analysis = context->GetDominatorAnalysis(f, cfg); + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(5, 5)); EXPECT_TRUE(analysis->Dominates(5, 17)); @@ -160,3 +159,5 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for.cpp index a720cf4ea..469e5c142 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for.cpp @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -81,16 +81,15 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::DominatorAnalysis* analysis = context->GetDominatorAnalysis(f, cfg); + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(5, 5)); EXPECT_TRUE(analysis->Dominates(5, 10)); EXPECT_TRUE(analysis->Dominates(5, 14)); @@ -118,3 +117,5 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for_post.cpp b/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for_post.cpp index 18a112ee2..8d3e37b4a 100644 --- a/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for_post.cpp +++ b/3rdparty/spirv-tools/test/opt/dominator_tree/unreachable_for_post.cpp @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/dominator_analysis.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -81,18 +81,16 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); + const Function* f = spvtest::GetFunction(module, 4); - ir::CFG cfg(module); - opt::PostDominatorAnalysis* analysis = - context->GetPostDominatorAnalysis(f, cfg); + PostDominatorAnalysis* analysis = context->GetPostDominatorAnalysis(f); EXPECT_TRUE(analysis->Dominates(12, 12)); EXPECT_TRUE(analysis->Dominates(12, 14)); @@ -116,3 +114,5 @@ TEST_F(PassClassTest, UnreachableNestedIfs) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/eliminate_dead_const_test.cpp b/3rdparty/spirv-tools/test/opt/eliminate_dead_const_test.cpp index 89613384e..7fac866ce 100644 --- a/3rdparty/spirv-tools/test/opt/eliminate_dead_const_test.cpp +++ b/3rdparty/spirv-tools/test/opt/eliminate_dead_const_test.cpp @@ -12,20 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" -#include "pass_fixture.h" -#include "pass_utils.h" - #include #include #include #include +#include #include +#include +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using EliminateDeadConstantBasicTest = PassTest<::testing::Test>; TEST_F(EliminateDeadConstantBasicTest, BasicAllDeadConstants) { @@ -73,7 +75,7 @@ TEST_F(EliminateDeadConstantBasicTest, BasicAllDeadConstants) { }); }); - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(text), expected_disassembly, /* skip_nop = */ true); } @@ -129,7 +131,7 @@ TEST_F(EliminateDeadConstantBasicTest, BasicNoneDeadConstants) { // clang-format on }; // All constants are used, so none of them should be eliminated. - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(text), JoinAllInsts(text), /* skip_nop = */ true); } @@ -191,7 +193,7 @@ TEST_P(EliminateDeadConstantTest, Custom) { const std::string expected = builder.GetCode(); builder.AppendTypesConstantsGlobals(tc.dead_consts); const std::string assembly_with_dead_const = builder.GetCode(); - SinglePassRunAndCheck( + SinglePassRunAndCheck( assembly_with_dead_const, expected, /* skip_nop = */ true); } @@ -839,4 +841,7 @@ INSTANTIATE_TEST_CASE_P( // Long Def-Use chain with swizzle // clang-format on }))); -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/eliminate_dead_functions_test.cpp b/3rdparty/spirv-tools/test/opt/eliminate_dead_functions_test.cpp index c780717d8..0a3d490a8 100644 --- a/3rdparty/spirv-tools/test/opt/eliminate_dead_functions_test.cpp +++ b/3rdparty/spirv-tools/test/opt/eliminate_dead_functions_test.cpp @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include -#include - -#include "assembly_builder.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::HasSubstr; - using EliminateDeadFunctionsBasicTest = PassTest<::testing::Test>; TEST_F(EliminateDeadFunctionsBasicTest, BasicDeleteDeadFunction) { @@ -61,7 +61,7 @@ TEST_F(EliminateDeadFunctionsBasicTest, BasicDeleteDeadFunction) { }; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(common_code, dead_function)), JoinAllInsts(common_code), /* skip_nop = */ true); } @@ -98,9 +98,9 @@ TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepLiveFunction) { SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); std::string assembly = JoinAllInsts(text); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( assembly, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); EXPECT_EQ(assembly, std::get<0>(result)); } @@ -137,9 +137,9 @@ TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepExportFunctions) { SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); std::string assembly = JoinAllInsts(text); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( assembly, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); EXPECT_EQ(assembly, std::get<0>(result)); } @@ -200,7 +200,10 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(text, expected_output, - /* skip_nop = */ true); + SinglePassRunAndCheck(text, expected_output, + /* skip_nop = */ true); } -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/feature_manager_test.cpp b/3rdparty/spirv-tools/test/opt/feature_manager_test.cpp index f67388a8c..767376cf5 100644 --- a/3rdparty/spirv-tools/test/opt/feature_manager_test.cpp +++ b/3rdparty/spirv-tools/test/opt/feature_manager_test.cpp @@ -12,14 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include +#include +#include -#include "opt/build_module.h" -#include "opt/ir_context.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" -using namespace spvtools; +namespace spvtools { +namespace opt { +namespace { using FeatureManagerTest = ::testing::Test; @@ -29,12 +33,12 @@ OpCapability Shader OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); EXPECT_FALSE(context->get_feature_mgr()->HasExtension( - libspirv::Extension::kSPV_KHR_variable_pointers)); + Extension::kSPV_KHR_variable_pointers)); } TEST_F(FeatureManagerTest, OneExtension) { @@ -44,12 +48,12 @@ OpMemoryModel Logical GLSL450 OpExtension "SPV_KHR_variable_pointers" )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); EXPECT_TRUE(context->get_feature_mgr()->HasExtension( - libspirv::Extension::kSPV_KHR_variable_pointers)); + Extension::kSPV_KHR_variable_pointers)); } TEST_F(FeatureManagerTest, NotADifferentExtension) { @@ -59,12 +63,12 @@ OpMemoryModel Logical GLSL450 OpExtension "SPV_KHR_variable_pointers" )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); EXPECT_FALSE(context->get_feature_mgr()->HasExtension( - libspirv::Extension::kSPV_KHR_storage_buffer_storage_class)); + Extension::kSPV_KHR_storage_buffer_storage_class)); } TEST_F(FeatureManagerTest, TwoExtensions) { @@ -75,14 +79,14 @@ OpExtension "SPV_KHR_variable_pointers" OpExtension "SPV_KHR_storage_buffer_storage_class" )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); EXPECT_TRUE(context->get_feature_mgr()->HasExtension( - libspirv::Extension::kSPV_KHR_variable_pointers)); + Extension::kSPV_KHR_variable_pointers)); EXPECT_TRUE(context->get_feature_mgr()->HasExtension( - libspirv::Extension::kSPV_KHR_storage_buffer_storage_class)); + Extension::kSPV_KHR_storage_buffer_storage_class)); } // Test capability checks. @@ -92,7 +96,7 @@ OpCapability Shader OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); @@ -106,7 +110,7 @@ OpCapability Kernel OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); @@ -120,7 +124,7 @@ OpCapability Tessellation OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); ASSERT_NE(context, nullptr); @@ -132,3 +136,7 @@ OpMemoryModel Logical GLSL450 EXPECT_TRUE(context->get_feature_mgr()->HasCapability(SpvCapabilityMatrix)); EXPECT_FALSE(context->get_feature_mgr()->HasCapability(SpvCapabilityKernel)); } + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/flatten_decoration_test.cpp b/3rdparty/spirv-tools/test/opt/flatten_decoration_test.cpp index 8e6d979ad..483ee6e53 100644 --- a/3rdparty/spirv-tools/test/opt/flatten_decoration_test.cpp +++ b/3rdparty/spirv-tools/test/opt/flatten_decoration_test.cpp @@ -13,15 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include -#include "pass_fixture.h" -#include "pass_utils.h" +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - // Returns the initial part of the assembly text for a valid // SPIR-V module, including instructions prior to decorations. std::string PreambleAssembly() { @@ -75,7 +77,7 @@ TEST_P(FlattenDecorationTest, TransformsDecorations) { const auto after = PreambleAssembly() + GetParam().expected + TypesAndFunctionsAssembly(); - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } INSTANTIATE_TEST_CASE_P(NoUses, FlattenDecorationTest, @@ -231,4 +233,6 @@ INSTANTIATE_TEST_CASE_P(UnrelatedDecorations, FlattenDecorationTest, "OpMemberDecorate %Point 1 Offset 4\n"}, }), ); -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/fold_spec_const_op_composite_test.cpp b/3rdparty/spirv-tools/test/opt/fold_spec_const_op_composite_test.cpp index a8debdf76..8ecfd5c78 100644 --- a/3rdparty/spirv-tools/test/opt/fold_spec_const_op_composite_test.cpp +++ b/3rdparty/spirv-tools/test/opt/fold_spec_const_op_composite_test.cpp @@ -12,20 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" - #include +#include +#include -#include "pass_fixture.h" -#include "pass_utils.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using FoldSpecConstantOpAndCompositePassBasicTest = PassTest<::testing::Test>; TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, Empty) { - SinglePassRunAndCheck( + SinglePassRunAndCheck( "", "", /* skip_nop = */ true); } @@ -73,7 +75,7 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, Basic) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true); } @@ -99,7 +101,7 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, // clang-format on }); - SinglePassRunAndCheck( + SinglePassRunAndCheck( builder.GetCode(), builder.GetCode(), /* skip_nop = */ true); } @@ -206,16 +208,16 @@ TEST_P(FoldSpecConstantOpAndCompositePassTest, ParamTestCase) { // Run the optimization and get the result code in disassembly. std::string optimized; - auto status = opt::Pass::Status::SuccessWithoutChange; + auto status = Pass::Status::SuccessWithoutChange; std::tie(optimized, status) = - SinglePassRunAndDisassemble( + SinglePassRunAndDisassemble( original, /* skip_nop = */ true, /* do_validation = */ false); // Check the optimized code, but ignore the OpName instructions. - EXPECT_NE(opt::Pass::Status::Failure, status); + EXPECT_NE(Pass::Status::Failure, status); EXPECT_EQ( StripOpNameInstructions(expected) == StripOpNameInstructions(original), - status == opt::Pass::Status::SuccessWithoutChange); + status == Pass::Status::SuccessWithoutChange); EXPECT_EQ(StripOpNameInstructions(expected), StripOpNameInstructions(optimized)); } @@ -441,13 +443,13 @@ INSTANTIATE_TEST_CASE_P( { "%true = OpConstantTrue %bool", "%true_0 = OpConstantTrue %bool", - "%spec_bool_t_vec = OpConstantComposite %v2bool %true_0 %true_0", + "%spec_bool_t_vec = OpConstantComposite %v2bool %bool_true %bool_true", "%false = OpConstantFalse %bool", "%false_0 = OpConstantFalse %bool", - "%spec_bool_f_vec = OpConstantComposite %v2bool %false_0 %false_0", + "%spec_bool_f_vec = OpConstantComposite %v2bool %bool_false %bool_false", "%false_1 = OpConstantFalse %bool", "%false_2 = OpConstantFalse %bool", - "%spec_bool_from_null = OpConstantComposite %v2bool %false_2 %false_2", + "%spec_bool_from_null = OpConstantComposite %v2bool %bool_false %bool_false", }, }, @@ -463,13 +465,13 @@ INSTANTIATE_TEST_CASE_P( { "%true = OpConstantTrue %bool", "%true_0 = OpConstantTrue %bool", - "%spec_bool_t_vec = OpConstantComposite %v2bool %true_0 %true_0", + "%spec_bool_t_vec = OpConstantComposite %v2bool %bool_true %bool_true", "%false = OpConstantFalse %bool", "%false_0 = OpConstantFalse %bool", - "%spec_bool_f_vec = OpConstantComposite %v2bool %false_0 %false_0", + "%spec_bool_f_vec = OpConstantComposite %v2bool %bool_false %bool_false", "%false_1 = OpConstantFalse %bool", "%false_2 = OpConstantFalse %bool", - "%spec_bool_from_null = OpConstantComposite %v2bool %false_2 %false_2", + "%spec_bool_from_null = OpConstantComposite %v2bool %bool_false %bool_false", }, }, @@ -485,13 +487,13 @@ INSTANTIATE_TEST_CASE_P( { "%int_1 = OpConstant %int 1", "%int_1_0 = OpConstant %int 1", - "%spec_int_one_vec = OpConstantComposite %v2int %int_1_0 %int_1_0", + "%spec_int_one_vec = OpConstantComposite %v2int %signed_one %signed_one", "%int_0 = OpConstant %int 0", "%int_0_0 = OpConstant %int 0", - "%spec_int_zero_vec = OpConstantComposite %v2int %int_0_0 %int_0_0", + "%spec_int_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero", "%int_0_1 = OpConstant %int 0", "%int_0_2 = OpConstant %int 0", - "%spec_int_from_null = OpConstantComposite %v2int %int_0_2 %int_0_2", + "%spec_int_from_null = OpConstantComposite %v2int %signed_zero %signed_zero", }, }, @@ -507,13 +509,13 @@ INSTANTIATE_TEST_CASE_P( { "%int_1 = OpConstant %int 1", "%int_1_0 = OpConstant %int 1", - "%spec_int_one_vec = OpConstantComposite %v2int %int_1_0 %int_1_0", + "%spec_int_one_vec = OpConstantComposite %v2int %signed_one %signed_one", "%int_0 = OpConstant %int 0", "%int_0_0 = OpConstant %int 0", - "%spec_int_zero_vec = OpConstantComposite %v2int %int_0_0 %int_0_0", + "%spec_int_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero", "%int_0_1 = OpConstant %int 0", "%int_0_2 = OpConstant %int 0", - "%spec_int_from_null = OpConstantComposite %v2int %int_0_2 %int_0_2", + "%spec_int_from_null = OpConstantComposite %v2int %signed_zero %signed_zero", }, }, @@ -529,13 +531,13 @@ INSTANTIATE_TEST_CASE_P( { "%uint_1 = OpConstant %uint 1", "%uint_1_0 = OpConstant %uint 1", - "%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1_0 %uint_1_0", + "%spec_uint_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", "%uint_0 = OpConstant %uint 0", "%uint_0_0 = OpConstant %uint 0", - "%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0_0 %uint_0_0", + "%spec_uint_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", "%uint_0_1 = OpConstant %uint 0", "%uint_0_2 = OpConstant %uint 0", - "%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_2 %uint_0_2", + "%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", }, }, @@ -551,13 +553,13 @@ INSTANTIATE_TEST_CASE_P( { "%uint_1 = OpConstant %uint 1", "%uint_1_0 = OpConstant %uint 1", - "%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1_0 %uint_1_0", + "%spec_uint_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", "%uint_0 = OpConstant %uint 0", "%uint_0_0 = OpConstant %uint 0", - "%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0_0 %uint_0_0", + "%spec_uint_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", "%uint_0_1 = OpConstant %uint 0", "%uint_0_2 = OpConstant %uint 0", - "%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_2 %uint_0_2", + "%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", }, }, // clang-format on @@ -836,13 +838,13 @@ INSTANTIATE_TEST_CASE_P( { "%int_n1 = OpConstant %int -1", "%int_n1_0 = OpConstant %int -1", - "%v2int_minus_1 = OpConstantComposite %v2int %int_n1_0 %int_n1_0", + "%v2int_minus_1 = OpConstantComposite %v2int %int_n1 %int_n1", "%int_n2 = OpConstant %int -2", "%int_n2_0 = OpConstant %int -2", - "%v2int_minus_2 = OpConstantComposite %v2int %int_n2_0 %int_n2_0", + "%v2int_minus_2 = OpConstantComposite %v2int %int_n2 %int_n2", "%int_0 = OpConstant %int 0", "%int_0_0 = OpConstant %int 0", - "%v2int_neg_null = OpConstantComposite %v2int %int_0_0 %int_0_0", + "%v2int_neg_null = OpConstantComposite %v2int %signed_zero %signed_zero", }, }, // vector integer (including null vetors) add, sub, div, mul @@ -865,35 +867,35 @@ INSTANTIATE_TEST_CASE_P( { "%int_5 = OpConstant %int 5", "%int_5_0 = OpConstant %int 5", - "%spec_v2int_iadd = OpConstantComposite %v2int %int_5_0 %int_5_0", + "%spec_v2int_iadd = OpConstantComposite %v2int %int_5 %int_5", "%int_n4 = OpConstant %int -4", "%int_n4_0 = OpConstant %int -4", - "%spec_v2int_isub = OpConstantComposite %v2int %int_n4_0 %int_n4_0", + "%spec_v2int_isub = OpConstantComposite %v2int %int_n4 %int_n4", "%int_n2 = OpConstant %int -2", "%int_n2_0 = OpConstant %int -2", - "%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2_0 %int_n2_0", + "%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2 %int_n2", "%int_n6 = OpConstant %int -6", "%int_n6_0 = OpConstant %int -6", - "%spec_v2int_imul = OpConstantComposite %v2int %int_n6_0 %int_n6_0", + "%spec_v2int_imul = OpConstantComposite %v2int %int_n6 %int_n6", "%int_n6_1 = OpConstant %int -6", "%int_n6_2 = OpConstant %int -6", - "%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6_2 %int_n6_2", + "%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6 %int_n6", "%uint_5 = OpConstant %uint 5", "%uint_5_0 = OpConstant %uint 5", - "%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5_0 %uint_5_0", + "%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5 %uint_5", "%uint_4294967292 = OpConstant %uint 4294967292", "%uint_4294967292_0 = OpConstant %uint 4294967292", - "%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292_0 %uint_4294967292_0", + "%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292 %uint_4294967292", "%uint_1431655764 = OpConstant %uint 1431655764", "%uint_1431655764_0 = OpConstant %uint 1431655764", - "%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764_0 %uint_1431655764_0", + "%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764 %uint_1431655764", "%uint_2863311528 = OpConstant %uint 2863311528", "%uint_2863311528_0 = OpConstant %uint 2863311528", - "%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528_0 %uint_2863311528_0", + "%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528", "%uint_2863311528_1 = OpConstant %uint 2863311528", "%uint_2863311528_2 = OpConstant %uint 2863311528", - "%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528_2 %uint_2863311528_2", + "%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528", }, }, // vector integer rem, mod @@ -938,33 +940,33 @@ INSTANTIATE_TEST_CASE_P( // srem "%int_1 = OpConstant %int 1", "%int_1_0 = OpConstant %int 1", - "%7_srem_3 = OpConstantComposite %v2int %int_1_0 %int_1_0", + "%7_srem_3 = OpConstantComposite %v2int %signed_one %signed_one", "%int_n1 = OpConstant %int -1", "%int_n1_0 = OpConstant %int -1", - "%minus_7_srem_3 = OpConstantComposite %v2int %int_n1_0 %int_n1_0", + "%minus_7_srem_3 = OpConstantComposite %v2int %int_n1 %int_n1", "%int_1_1 = OpConstant %int 1", "%int_1_2 = OpConstant %int 1", - "%7_srem_minus_3 = OpConstantComposite %v2int %int_1_2 %int_1_2", + "%7_srem_minus_3 = OpConstantComposite %v2int %signed_one %signed_one", "%int_n1_1 = OpConstant %int -1", "%int_n1_2 = OpConstant %int -1", - "%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1_2 %int_n1_2", + "%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1", // smod "%int_1_3 = OpConstant %int 1", "%int_1_4 = OpConstant %int 1", - "%7_smod_3 = OpConstantComposite %v2int %int_1_4 %int_1_4", + "%7_smod_3 = OpConstantComposite %v2int %signed_one %signed_one", "%int_2 = OpConstant %int 2", "%int_2_0 = OpConstant %int 2", - "%minus_7_smod_3 = OpConstantComposite %v2int %int_2_0 %int_2_0", + "%minus_7_smod_3 = OpConstantComposite %v2int %signed_two %signed_two", "%int_n2 = OpConstant %int -2", "%int_n2_0 = OpConstant %int -2", - "%7_smod_minus_3 = OpConstantComposite %v2int %int_n2_0 %int_n2_0", + "%7_smod_minus_3 = OpConstantComposite %v2int %int_n2 %int_n2", "%int_n1_3 = OpConstant %int -1", "%int_n1_4 = OpConstant %int -1", - "%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1_4 %int_n1_4", + "%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1", // umod "%uint_1 = OpConstant %uint 1", "%uint_1_0 = OpConstant %uint 1", - "%7_umod_3 = OpConstantComposite %v2uint %uint_1_0 %uint_1_0", + "%7_umod_3 = OpConstantComposite %v2uint %unsigned_one %unsigned_one", }, }, // vector integer bitwise, shift @@ -985,25 +987,25 @@ INSTANTIATE_TEST_CASE_P( { "%int_2 = OpConstant %int 2", "%int_2_0 = OpConstant %int 2", - "%xor_1_3 = OpConstantComposite %v2int %int_2_0 %int_2_0", + "%xor_1_3 = OpConstantComposite %v2int %signed_two %signed_two", "%int_0 = OpConstant %int 0", "%int_0_0 = OpConstant %int 0", - "%and_1_2 = OpConstantComposite %v2int %int_0_0 %int_0_0", + "%and_1_2 = OpConstantComposite %v2int %signed_zero %signed_zero", "%int_3 = OpConstant %int 3", "%int_3_0 = OpConstant %int 3", - "%or_1_2 = OpConstantComposite %v2int %int_3_0 %int_3_0", + "%or_1_2 = OpConstantComposite %v2int %signed_three %signed_three", "%unsigned_31 = OpConstant %uint 31", "%v2unsigned_31 = OpConstantComposite %v2uint %unsigned_31 %unsigned_31", "%uint_2147483648 = OpConstant %uint 2147483648", "%uint_2147483648_0 = OpConstant %uint 2147483648", - "%unsigned_left_shift_max = OpConstantComposite %v2uint %uint_2147483648_0 %uint_2147483648_0", + "%unsigned_left_shift_max = OpConstantComposite %v2uint %uint_2147483648 %uint_2147483648", "%uint_1 = OpConstant %uint 1", "%uint_1_0 = OpConstant %uint 1", - "%unsigned_right_shift_logical = OpConstantComposite %v2uint %uint_1_0 %uint_1_0", + "%unsigned_right_shift_logical = OpConstantComposite %v2uint %unsigned_one %unsigned_one", "%int_n1 = OpConstant %int -1", "%int_n1_0 = OpConstant %int -1", - "%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1_0 %int_n1_0", + "%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1 %int_n1", }, }, // Skip folding if any vector operands or components of the operands @@ -1255,13 +1257,13 @@ INSTANTIATE_TEST_CASE_P( // expected { "%60 = OpConstantNull %int", - "%a = OpConstantComposite %v2int %60 %60", + "%a = OpConstantComposite %v2int %signed_null %signed_null", "%62 = OpConstantNull %int", "%b = OpConstantComposite %v2int %signed_zero %signed_one", "%64 = OpConstantNull %int", - "%c = OpConstantComposite %v2int %signed_three %64", + "%c = OpConstantComposite %v2int %signed_three %signed_null", "%66 = OpConstantNull %int", - "%d = OpConstantComposite %v2int %66 %66", + "%d = OpConstantComposite %v2int %signed_null %signed_null", } }, // skip if any of the components of the vector operands do not have @@ -1377,7 +1379,7 @@ INSTANTIATE_TEST_CASE_P( "%used_vec_a = OpConstantComposite %v2int %spec_int_18 %spec_int_19", "%int_10201 = OpConstant %int 10201", "%int_1 = OpConstant %int 1", - "%used_vec_b = OpConstantComposite %v2int %int_10201 %int_1", + "%used_vec_b = OpConstantComposite %v2int %int_10201 %signed_one", "%spec_int_21 = OpConstant %int 10201", "%array = OpConstantComposite %type_arr_int_4 %spec_int_20 %spec_int_20 %spec_int_21 %spec_int_21", "%spec_int_22 = OpSpecConstant %int 123", @@ -1386,4 +1388,7 @@ INSTANTIATE_TEST_CASE_P( }, // Long Def-Use chain with swizzle }))); -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/fold_test.cpp b/3rdparty/spirv-tools/test/opt/fold_test.cpp index c33bbd99c..b1e575886 100644 --- a/3rdparty/spirv-tools/test/opt/fold_test.cpp +++ b/3rdparty/spirv-tools/test/opt/fold_test.cpp @@ -11,33 +11,34 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include +#include -#include -#include -#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/fold.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" +#include "test/opt/pass_utils.h" #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "opt/build_module.h" -#include "opt/def_use_manager.h" -#include "opt/ir_context.h" -#include "opt/module.h" -#include "pass_utils.h" -#include "spirv-tools/libspirv.hpp" - +namespace spvtools { +namespace opt { namespace { using ::testing::Contains; -using namespace spvtools; -using spvtools::opt::analysis::DefUseManager; - #ifdef SPIRV_EFFCEE -std::string Disassemble(const std::string& original, ir::IRContext* context, +std::string Disassemble(const std::string& original, IRContext* context, uint32_t disassemble_options = 0) { std::vector optimized_bin; context->module()->ToBinary(&optimized_bin, true); @@ -51,7 +52,7 @@ std::string Disassemble(const std::string& original, ir::IRContext* context, return optimized_asm; } -void Match(const std::string& original, ir::IRContext* context, +void Match(const std::string& original, IRContext* context, uint32_t disassemble_options = 0) { std::string disassembly = Disassemble(original, context, disassemble_options); auto match_result = effcee::Match(disassembly, original); @@ -78,15 +79,15 @@ TEST_P(IntegerInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. EXPECT_TRUE(succeeded); @@ -94,8 +95,8 @@ TEST_P(IntegerInstructionFoldingTest, Case) { EXPECT_EQ(inst->opcode(), SpvOpCopyObject); inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); EXPECT_EQ(inst->opcode(), SpvOpConstant); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::IntConstant* result = + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::IntConstant* result = const_mrg->GetConstantFromInst(inst)->AsIntConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { @@ -117,6 +118,9 @@ TEST_P(IntegerInstructionFoldingTest, Case) { const std::string& Header() { static const std::string header = R"(OpCapability Shader OpCapability Float16 +OpCapability Float64 +OpCapability Int16 +OpCapability Int64 %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" @@ -126,7 +130,6 @@ OpName %main "main" %void = OpTypeVoid %void_func = OpTypeFunction %void %bool = OpTypeBool -%float16 = OpTypeFloat 16 %float = OpTypeFloat 32 %double = OpTypeFloat 64 %half = OpTypeFloat 16 @@ -143,6 +146,7 @@ OpName %main "main" %v4float = OpTypeVector %float 4 %v4double = OpTypeVector %double 4 %v2float = OpTypeVector %float 2 +%v2double = OpTypeVector %double 2 %v2bool = OpTypeVector %bool 2 %struct_v2int_int_int = OpTypeStruct %v2int %int %int %_ptr_int = OpTypePointer Function %int @@ -158,7 +162,9 @@ OpName %main "main" %_ptr_v4double = OpTypePointer Function %v4double %_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int %_ptr_v2float = OpTypePointer Function %v2float +%_ptr_v2double = OpTypePointer Function %v2double %short_0 = OpConstant %short 0 +%short_2 = OpConstant %short 2 %short_3 = OpConstant %short 3 %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. %103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps. @@ -173,12 +179,15 @@ OpName %main "main" %long_2 = OpConstant %long 2 %long_3 = OpConstant %long 3 %uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 %uint_2 = OpConstant %uint 2 %uint_3 = OpConstant %uint 3 %uint_4 = OpConstant %uint 4 %uint_32 = OpConstant %uint 32 %uint_max = OpConstant %uint 4294967295 %v2int_undef = OpUndef %v2int +%v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0 +%v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0 %v2int_2_2 = OpConstantComposite %v2int %int_2 %int_2 %v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3 %v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2 @@ -191,18 +200,16 @@ OpName %main "main" %102 = OpConstantComposite %v2int %103 %103 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 %struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0 -%float16_0 = OpConstant %float16 0 -%float16_1 = OpConstant %float16 1 -%float16_2 = OpConstant %float16 2 %float_n1 = OpConstant %float -1 %104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps. +%float_null = OpConstantNull %float %float_0 = OpConstant %float 0 -%float_half = OpConstant %float 0.5 %float_1 = OpConstant %float 1 %float_2 = OpConstant %float 2 %float_3 = OpConstant %float 3 %float_4 = OpConstant %float 4 %float_0p5 = OpConstant %float 0.5 +%v2float_0_0 = OpConstantComposite %v2float %float_0 %float_0 %v2float_2_2 = OpConstantComposite %v2float %float_2 %float_2 %v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3 %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2 @@ -211,27 +218,55 @@ OpName %main "main" %v2float_null = OpConstantNull %v2float %double_n1 = OpConstant %double -1 %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps. +%double_null = OpConstantNull %double %double_0 = OpConstant %double 0 %double_1 = OpConstant %double 1 %double_2 = OpConstant %double 2 %double_3 = OpConstant %double 3 -%float_nan = OpConstant %float -0x1.8p+128 -%double_nan = OpConstant %double -0x1.8p+1024 +%double_4 = OpConstant %double 4 +%double_0p5 = OpConstant %double 0.5 +%v2double_0_0 = OpConstantComposite %v2double %double_0 %double_0 +%v2double_2_2 = OpConstantComposite %v2double %double_2 %double_2 +%v2double_2_3 = OpConstantComposite %v2double %double_2 %double_3 +%v2double_3_2 = OpConstantComposite %v2double %double_3 %double_2 +%v2double_4_4 = OpConstantComposite %v2double %double_4 %double_4 +%v2double_2_0p5 = OpConstantComposite %v2double %double_2 %double_0p5 +%v2double_null = OpConstantNull %v2double %108 = OpConstant %half 0 %half_1 = OpConstant %half 1 %106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1 +%v4float_0_1_0_0 = OpConstantComposite %v4float %float_0 %float_1 %float_null %float_0 %v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 %107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0 %v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0 %v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1 +%v4double_0_1_0_0 = OpConstantComposite %v4double %double_0 %double_1 %double_null %double_0 %v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1 +%v4double_1_1_1_0p5 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_0p5 +%v4double_null = OpConstantNull %v4double +%v4float_n1_2_1_3 = OpConstantComposite %v4float %float_n1 %float_2 %float_1 %float_3 )"; return header; } +// Returns the header with definitions of float NaN and double NaN. Since FC +// "; CHECK: [[double_n0:%\\w+]] = OpConstant [[double]] -0\n" finds +// %double_nan = OpConstant %double -0x1.8p+1024 instead of +// %double_n0 = OpConstant %double -0, +// we separates those definitions from Header(). +const std::string& HeaderWithNaN() { + static const std::string headerWithNaN = + Header() + + R"(%float_nan = OpConstant %float -0x1.8p+128 +%double_nan = OpConstant %double -0x1.8p+1024 +)"; + + return headerWithNaN; +} + // clang-format off INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest, ::testing::Values( @@ -415,29 +450,29 @@ TEST_P(IntVectorInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + SpvOp original_opcode = inst->opcode(); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. - EXPECT_TRUE(succeeded); - if (inst != nullptr) { + EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode); + if (succeeded && inst != nullptr) { EXPECT_EQ(inst->opcode(), SpvOpCopyObject); inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); std::vector opcodes = {SpvOpConstantComposite}; EXPECT_THAT(opcodes, Contains(inst->opcode())); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::Constant* result = - const_mrg->GetConstantFromInst(inst); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::Constant* result = const_mrg->GetConstantFromInst(inst); EXPECT_NE(result, nullptr); if (result != nullptr) { - const std::vector& componenets = + const std::vector& componenets = result->AsVectorConstant()->GetComponents(); EXPECT_EQ(componenets.size(), tc.expected_result.size()); for (size_t i = 0; i < componenets.size(); i++) { @@ -468,7 +503,25 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntVectorInstructionFoldingTest, "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 3\n" + "OpReturn\n" + "OpFunctionEnd", - 2, {0,3}) + 2, {0,3}), + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 4294967295 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {0,0}), + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 4294967295 \n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {0,0}) )); // clang-format on @@ -479,15 +532,15 @@ TEST_P(BooleanInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. EXPECT_TRUE(succeeded); @@ -496,8 +549,8 @@ TEST_P(BooleanInstructionFoldingTest, Case) { inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); std::vector bool_opcodes = {SpvOpConstantTrue, SpvOpConstantFalse}; EXPECT_THAT(bool_opcodes, Contains(inst->opcode())); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::BoolConstant* result = + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::BoolConstant* result = const_mrg->GetConstantFromInst(inst)->AsBoolConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { @@ -710,6 +763,377 @@ INSTANTIATE_TEST_CASE_P(TestCase, BooleanInstructionFoldingTest, "OpFunctionEnd", 2, true) )); + +INSTANTIATE_TEST_CASE_P(FClampAndCmpLHS, BooleanInstructionFoldingTest, +::testing::Values( + // Test case 0: fold 0.0 > clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold 0.0 > clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold 0.0 >= clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 3: fold 0.0 >= clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 4: fold 0.0 <= clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold 0.0 <= clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFOrdLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold 0.0 < clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold 0.0 < clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 0.0 > clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold 0.0 > clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 10: fold 0.0 >= clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 11: fold 0.0 >= clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold 0.0 <= clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 13: fold 0.0 <= clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFUnordLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 14: fold 0.0 < clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 15: fold 0.0 < clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false) +)); + +INSTANTIATE_TEST_CASE_P(FClampAndCmpRHS, BooleanInstructionFoldingTest, +::testing::Values( + // Test case 0: fold clamp(n, 0.0, 1.0) > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold clamp(n, 1.0, 1.0) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold clamp(n, 1, 2) >= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdGreaterThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: fold clamp(n, 1.0, 2.0) >= 3.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdGreaterThanEqual %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 4: fold clamp(n, 0.0, 1.0) <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold clamp(n, 1.0, 2.0) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold clamp(n, 1, 2) < 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold clamp(n, -1.0, 0.0) < -1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold clamp(n, 0.0, 1.0) > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold clamp(n, 1.0, 2.0) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 10: fold clamp(n, 1, 2) >= 3.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordGreaterThanEqual %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 11: fold clamp(n, -1.0, 0.0) >= -1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThanEqual %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold clamp(n, 0.0, 1.0) <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 13: fold clamp(n, 1.0, 1.0) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 14: fold clamp(n, 1, 2) < 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 15: fold clamp(n, -1.0, 0.0) < -1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 16: fold clamp(n, -1.0, 0.0) < -1.0 (one test for double) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%ld = OpLoad %double %n\n" + + "%clamp = OpExtInst %double %1 FClamp %ld %double_n1 %double_0\n" + + "%2 = OpFUnordLessThan %bool %clamp %double_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false) +)); // clang-format on using FloatInstructionFoldingTest = @@ -719,15 +1143,15 @@ TEST_P(FloatInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. EXPECT_TRUE(succeeded); @@ -735,8 +1159,8 @@ TEST_P(FloatInstructionFoldingTest, Case) { EXPECT_EQ(inst->opcode(), SpvOpCopyObject); inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); EXPECT_EQ(inst->opcode(), SpvOpConstant); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::FloatConstant* result = + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::FloatConstant* result = const_mrg->GetConstantFromInst(inst)->AsFloatConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { @@ -799,7 +1223,63 @@ INSTANTIATE_TEST_CASE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest, "%2 = OpFDiv %float %float_n1 %float_0\n" + "OpReturn\n" + "OpFunctionEnd", - 2, -std::numeric_limits::infinity()) + 2, -std::numeric_limits::infinity()), + // Test case 6: Fold (2.0, 3.0) dot (2.0, 0.5) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpDot %float %v2float_2_3 %v2float_2_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 5.5f), + // Test case 7: Fold (0.0, 0.0) dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %v2float_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 8: Fold v dot (0.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %2 %v2float_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 9: Fold Null dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %v2float_null %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 10: Fold v dot Null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 11: Fold -2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %float %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -2) )); // clang-format on @@ -810,15 +1290,15 @@ TEST_P(DoubleInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. EXPECT_TRUE(succeeded); @@ -826,8 +1306,8 @@ TEST_P(DoubleInstructionFoldingTest, Case) { EXPECT_EQ(inst->opcode(), SpvOpCopyObject); inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); EXPECT_EQ(inst->opcode(), SpvOpConstant); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::FloatConstant* result = + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::FloatConstant* result = const_mrg->GetConstantFromInst(inst)->AsFloatConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { @@ -879,14 +1359,70 @@ INSTANTIATE_TEST_CASE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest, "OpReturn\n" + "OpFunctionEnd", 2, std::numeric_limits::infinity()), - // Test case 4: Fold -1.0 / 0.0 + // Test case 5: Fold -1.0 / 0.0 InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFDiv %double %double_n1 %double_0\n" + "OpReturn\n" + "OpFunctionEnd", - 2, -std::numeric_limits::infinity()) + 2, -std::numeric_limits::infinity()), + // Test case 6: Fold (2.0, 3.0) dot (2.0, 0.5) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpDot %double %v2double_2_3 %v2double_2_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 5.5f), + // Test case 7: Fold (0.0, 0.0) dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %v2double_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 8: Fold v dot (0.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %2 %v2double_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 9: Fold Null dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %v2double_null %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 10: Fold v dot Null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %2 %v2double_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 11: Fold -2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %double %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -2) )); // clang-format on @@ -1423,7 +1959,7 @@ INSTANTIATE_TEST_CASE_P(DoubleNaNCompareConstantFoldingTest, BooleanInstructionF ::testing::Values( // Test case 0: fold NaN == 0 (ord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFOrdEqual %bool %double_nan %double_0\n" + "OpReturn\n" + @@ -1431,7 +1967,7 @@ INSTANTIATE_TEST_CASE_P(DoubleNaNCompareConstantFoldingTest, BooleanInstructionF 2, false), // Test case 1: fold NaN == NaN (unord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFUnordEqual %bool %double_nan %double_0\n" + "OpReturn\n" + @@ -1439,7 +1975,7 @@ INSTANTIATE_TEST_CASE_P(DoubleNaNCompareConstantFoldingTest, BooleanInstructionF 2, true), // Test case 2: fold NaN != NaN (ord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFOrdNotEqual %bool %double_nan %double_0\n" + "OpReturn\n" + @@ -1447,7 +1983,7 @@ INSTANTIATE_TEST_CASE_P(DoubleNaNCompareConstantFoldingTest, BooleanInstructionF 2, false), // Test case 3: fold NaN != NaN (unord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFUnordNotEqual %bool %double_nan %double_0\n" + "OpReturn\n" + @@ -1459,7 +1995,7 @@ INSTANTIATE_TEST_CASE_P(FloatNaNCompareConstantFoldingTest, BooleanInstructionFo ::testing::Values( // Test case 0: fold NaN == 0 (ord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFOrdEqual %bool %float_nan %float_0\n" + "OpReturn\n" + @@ -1467,7 +2003,7 @@ INSTANTIATE_TEST_CASE_P(FloatNaNCompareConstantFoldingTest, BooleanInstructionFo 2, false), // Test case 1: fold NaN == NaN (unord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFUnordEqual %bool %float_nan %float_0\n" + "OpReturn\n" + @@ -1475,7 +2011,7 @@ INSTANTIATE_TEST_CASE_P(FloatNaNCompareConstantFoldingTest, BooleanInstructionFo 2, true), // Test case 2: fold NaN != NaN (ord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFOrdNotEqual %bool %float_nan %float_0\n" + "OpReturn\n" + @@ -1483,7 +2019,7 @@ INSTANTIATE_TEST_CASE_P(FloatNaNCompareConstantFoldingTest, BooleanInstructionFo 2, false), // Test case 3: fold NaN != NaN (unord) InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFUnordNotEqual %bool %float_nan %float_0\n" + "OpReturn\n" + @@ -1512,22 +2048,23 @@ TEST_P(IntegerInstructionFoldingTestWithMap, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstructionToConstant(inst, tc.id_map); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + inst = context->get_instruction_folder().FoldInstructionToConstant(inst, + tc.id_map); // Make sure the instruction folded as expected. EXPECT_NE(inst, nullptr); if (inst != nullptr) { EXPECT_EQ(inst->opcode(), SpvOpConstant); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::IntConstant* result = + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::IntConstant* result = const_mrg->GetConstantFromInst(inst)->AsIntConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { @@ -1559,23 +2096,24 @@ TEST_P(BooleanInstructionFoldingTestWithMap, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - inst = opt::FoldInstructionToConstant(inst, tc.id_map); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + inst = context->get_instruction_folder().FoldInstructionToConstant(inst, + tc.id_map); // Make sure the instruction folded as expected. EXPECT_NE(inst, nullptr); if (inst != nullptr) { std::vector bool_opcodes = {SpvOpConstantTrue, SpvOpConstantFalse}; EXPECT_THAT(bool_opcodes, Contains(inst->opcode())); - opt::analysis::ConstantManager* const_mrg = context->get_constant_mgr(); - const opt::analysis::BoolConstant* result = + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::BoolConstant* result = const_mrg->GetConstantFromInst(inst)->AsBoolConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { @@ -1608,16 +2146,16 @@ TEST_P(GeneralInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - std::unique_ptr original_inst(inst->Clone(context.get())); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. EXPECT_EQ(inst->result_id(), original_inst->result_id()); @@ -2057,19 +2595,19 @@ INSTANTIATE_TEST_CASE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTes "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 38: Don't fold 0 + 3 (long), bad length + // Test case 38: Don't fold 2 + 3 (long), bad length InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + - "%2 = OpIAdd %long %long_0 %long_3\n" + + "%2 = OpIAdd %long %long_2 %long_3\n" + "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 39: Don't fold 0 + 3 (short), bad length + // Test case 39: Don't fold 2 + 3 (short), bad length InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + - "%2 = OpIAdd %short %short_0 %short_3\n" + + "%2 = OpIAdd %short %short_2 %short_3\n" + "OpReturn\n" + "OpFunctionEnd", 2, 0), @@ -2794,6 +3332,362 @@ INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFold 2, 3) )); +INSTANTIATE_TEST_CASE_P(IntegerRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold n + 1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %3 %uint_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't fold 1 + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %uint_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Fold n + 0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %3 %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 3: Fold 0 + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %uint_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 4: Don't fold n + (1,0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %3 %v2int_1_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't fold (1,0) + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %v2int_1_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Fold n + (0,0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %3 %v2int_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 7: Fold (0,0) + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %v2int_0_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(ClampAndCmpLHS, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: Don't Fold 0.0 < clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't Fold 0.0 < clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't Fold 0.0 <= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Don't Fold 0.0 <= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 4: Don't Fold 0.0 > clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't Fold 0.0 > clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Don't Fold 0.0 >= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 7: Don't Fold 0.0 >= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 8: Don't Fold 0.0 < clamp(0, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: Don't Fold 0.0 < clamp(0, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 10: Don't Fold 0.0 > clamp(-1, 0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 11: Don't Fold 0.0 > clamp(-1, 0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) +)); + +INSTANTIATE_TEST_CASE_P(ClampAndCmpRHS, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: Don't Fold clamp(-1, 1) < 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't Fold clamp(-1, 1) < 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't Fold clamp(-1, 1) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Don't Fold clamp(-1, 1) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 4: Don't Fold clamp(-1, 1) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't Fold clamp(-1, 1) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Don't Fold clamp(-1, 1) >= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 7: Don't Fold clamp(-1, 1) >= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 8: Don't Fold clamp(-1, 0) < 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: Don't Fold clamp(0, 1) < 1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 10: Don't Fold clamp(-1, 0) > -1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 11: Don't Fold clamp(-1, 0) > -1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) +)); + INSTANTIATE_TEST_CASE_P(FToIConstantFoldingTest, IntegerInstructionFoldingTest, ::testing::Values( // Test case 0: Fold int(3.0) @@ -2842,16 +3736,16 @@ TEST_P(ToNegateFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - std::unique_ptr original_inst(inst->Clone(context.get())); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); // Make sure the instruction folded as expected. EXPECT_EQ(inst->result_id(), original_inst->result_id()); @@ -2965,22 +3859,52 @@ TEST_P(MatchingInstructionFoldingTest, Case) { const auto& tc = GetParam(); // Build module. - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(nullptr, context); // Fold the instruction to test. - opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); - std::unique_ptr original_inst(inst->Clone(context.get())); - bool succeeded = opt::FoldInstruction(inst); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); EXPECT_EQ(succeeded, tc.expected_result); if (succeeded) { Match(tc.test_body, context.get()); } } +INSTANTIATE_TEST_CASE_P(RedundantIntegerMatching, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold 0 + n (change sign) + InstructionFoldingCase( + Header() + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: %2 = OpBitcast [[uint]] %3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %int_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 2, true), + // Test case 0: Fold 0 + n (change sign) + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %2 = OpBitcast [[int]] %3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %int %n\n" + + "%2 = OpIAdd %int %uint_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 2, true) +)); + INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest, ::testing::Values( // Test case 0: fold consecutive fnegate @@ -3284,7 +4208,38 @@ INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest, "%4 = OpSNegate %long %3\n" + "OpReturn\n" + "OpFunctionEnd", - 4, true) + 4, true), + // Test case 18: fold -vec4(-1.0, 2.0, 1.0, 3.0) + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v4float:%\\w+]] = OpTypeVector [[float]] 4\n" + + "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n3:%\\w+]] = OpConstant [[float]] -3\n" + + "; CHECK: [[v4float_1_n2_n1_n3:%\\w+]] = OpConstantComposite [[v4float]] [[float_1]] [[float_n2]] [[float_n1]] [[float_n3]]\n" + + "; CHECK: %2 = OpCopyObject [[v4float]] [[v4float_1_n2_n1_n3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %v4float %v4float_n1_2_1_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 19: fold vector fnegate with null + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[double_n0:%\\w+]] = OpConstant [[double]] -0\n" + + "; CHECK: [[v2double_0_0:%\\w+]] = OpConstantComposite [[v2double]] [[double_n0]] [[double_n0]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_0_0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %v2double %v2double_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) )); INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest, @@ -3665,7 +4620,123 @@ INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest, "%4 = OpIMul %v2int %v2int_2_2 %3\n" + "OpReturn\n" + "OpFunctionEnd\n", - 4, true) + 4, true), + // Test case 18: Fold OpVectorTimesScalar + // {4,4} = OpVectorTimesScalar v2float {2,2} 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" + + "; CHECK: [[v2float_4_4:%\\w+]] = OpConstantComposite [[v2float]] [[float_4]] [[float_4]]\n" + + "; CHECK: %2 = OpCopyObject [[v2float]] [[v2float_4_4]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVectorTimesScalar %v2float %v2float_2_2 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 19: Fold OpVectorTimesScalar + // {0,0} = OpVectorTimesScalar v2float v2float_null -1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[v2float_null:%\\w+]] = OpConstantNull [[v2float]]\n" + + "; CHECK: %2 = OpCopyObject [[v2float]] [[v2float_null]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVectorTimesScalar %v2float %v2float_null %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 20: Fold OpVectorTimesScalar + // {4,4} = OpVectorTimesScalar v2double {2,2} 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[double_4:%\\w+]] = OpConstant [[double]] 4\n" + + "; CHECK: [[v2double_4_4:%\\w+]] = OpConstantComposite [[v2double]] [[double_4]] [[double_4]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_4_4]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVectorTimesScalar %v2double %v2double_2_2 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 21: Fold OpVectorTimesScalar + // {0,0} = OpVectorTimesScalar v2double {0,0} n + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: {{%\\w+}} = OpConstant [[double]] 0\n" + + "; CHECK: [[double_0:%\\w+]] = OpConstant [[double]] 0\n" + + "; CHECK: [[v2double_0_0:%\\w+]] = OpConstantComposite [[v2double]] [[double_0]] [[double_0]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_0_0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%load = OpLoad %double %n\n" + + "%2 = OpVectorTimesScalar %v2double %v2double_0_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 22: Fold OpVectorTimesScalar + // {0,0} = OpVectorTimesScalar v2double n 0 + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[v2double_null:%\\w+]] = OpConstantNull [[v2double]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_null]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2double Function\n" + + "%load = OpLoad %v2double %n\n" + + "%2 = OpVectorTimesScalar %v2double %load %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 23: merge fmul of fdiv + // x * (y / x) = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFDiv %float %3 %2\n" + + "%5 = OpFMul %float %2 %4\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true), + // Test case 24: merge fmul of fdiv + // (y / x) * x = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFDiv %float %3 %2\n" + + "%5 = OpFMul %float %4 %2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true) )); INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest, @@ -3873,7 +4944,45 @@ INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest, "%4 = OpFDiv %float %3 %v2float_null\n" + "OpReturn\n" + "OpFunctionEnd\n", - 4, false) + 4, false), + // Test case 14: merge fmul of fdiv + // (y * x) / x = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFMul %float %3 %2\n" + + "%5 = OpFDiv %float %4 %2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true), + // Test case 15: merge fmul of fdiv + // (x * y) / x = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFMul %float %2 %3\n" + + "%5 = OpFDiv %float %4 %2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true) )); INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest, @@ -4464,7 +5573,510 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractMatchingTest, MatchingInstructionFolding "%4 = OpCompositeExtract %int %3 1\n" + "OpReturn\n" + "OpFunctionEnd", + 4, true), + // Test case 3: Using fmix feeding extract with a 1 in the a position. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 4\n" + + "; CHECK: [[ptr_v4double:%\\w+]] = OpTypePointer Function [[v4double]]\n" + + "; CHECK: [[m:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[n:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v4double]] [[n]]\n" + + "; CHECK: %5 = OpCompositeExtract [[double]] [[ld]] 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_0_1_0_0\n" + + "%5 = OpCompositeExtract %double %4 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 4: Using fmix feeding extract with a 0 in the a position. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 4\n" + + "; CHECK: [[ptr_v4double:%\\w+]] = OpTypePointer Function [[v4double]]\n" + + "; CHECK: [[m:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[n:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v4double]] [[m]]\n" + + "; CHECK: %5 = OpCompositeExtract [[double]] [[ld]] 2\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_0_1_0_0\n" + + "%5 = OpCompositeExtract %double %4 2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 5: Using fmix feeding extract with a null for the alpha + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 4\n" + + "; CHECK: [[ptr_v4double:%\\w+]] = OpTypePointer Function [[v4double]]\n" + + "; CHECK: [[m:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[n:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v4double]] [[m]]\n" + + "; CHECK: %5 = OpCompositeExtract [[double]] [[ld]] 0\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_null\n" + + "%5 = OpCompositeExtract %double %4 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 6: Don't fold: Using fmix feeding extract with 0.5 in the a + // position. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_1_1_1_0p5\n" + + "%5 = OpCompositeExtract %double %4 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, false), + // Test case 7: Extracting the undefined literal value from a vector + // shuffle. + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %4 = OpUndef [[int]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4int Function\n" + + "%2 = OpLoad %v4int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %2 2 4294967295\n" + + "%4 = OpCompositeExtract %int %3 1\n" + + "OpReturn\n" + + "OpFunctionEnd", 4, true) )); + +INSTANTIATE_TEST_CASE_P(DotProductMatchingTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%2 = OpLoad %v4float %n\n" + + "%3 = OpDot %float %2 %v4float_0_0_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 1: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%2 = OpLoad %v4float %n\n" + + "%3 = OpDot %float %v4float_0_0_0_1 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 2: Using OpDot to extract second element. + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: %3 = OpCompositeExtract [[float]] %2 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%2 = OpLoad %v4float %n\n" + + "%3 = OpDot %float %v4float_0_1_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 3: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %n\n" + + "%3 = OpDot %double %2 %v4double_0_0_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 4: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %n\n" + + "%3 = OpDot %double %v4double_0_0_0_1 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 5: Using OpDot to extract second element. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: %3 = OpCompositeExtract [[double]] %2 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %n\n" + + "%3 = OpDot %double %v4double_0_1_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true) +)); + +using MatchingInstructionWithNoResultFoldingTest = +::testing::TestWithParam>; + +// Test folding instructions that do not have a result. The instruction +// that will be folded is the last instruction before the return. If there +// are multiple returns, there is not guarentee which one is used. +TEST_P(MatchingInstructionWithNoResultFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + Instruction* inst = nullptr; + Function* func = &*context->module()->begin(); + for (auto& bb : *func) { + Instruction* terminator = bb.terminator(); + if (terminator->IsReturnOrAbort()) { + inst = terminator->PreviousNode(); + break; + } + } + assert(inst && "Invalid test. Could not find instruction to fold."); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + EXPECT_EQ(succeeded, tc.expected_result); + if (succeeded) { + Match(tc.test_body, context.get()); + } +} + +INSTANTIATE_TEST_CASE_P(StoreMatchingTest, MatchingInstructionWithNoResultFoldingTest, +::testing::Values( + // Test case 0: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: OpLabel\n" + + "; CHECK-NOT: OpStore\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%undef = OpUndef %v4double\n" + + "OpStore %n %undef\n" + + "OpReturn\n" + + "OpFunctionEnd", + 0 /* OpStore */, true) +)); + +INSTANTIATE_TEST_CASE_P(VectorShuffleMatchingTest, MatchingInstructionWithNoResultFoldingTest, +::testing::Values( + // Test case 0: Basic test 1 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 %5 2 3 6 7\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 3 4 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 1: Basic test 2 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %6 %7 0 1 4 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 3 4 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 2: Basic test 3 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 3 2 4 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 1 0 4 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 3: Basic test 4 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 %6 2 3 5 4\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 3 7 6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 4: Don't fold, need both operands of the feeder. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 3 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, false), + // Test case 5: Don't fold, need both operands of the feeder. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, false), + // Test case 6: Fold, need both operands of the feeder, but they are the same. + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 0 2 7 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %5 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 7: Fold, need both operands of the feeder, but they are the same. + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 %5 2 0 5 7\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %5 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 8: Replace first operand with a smaller vector. + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 0 0 5 3\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v2double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v2double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %5 0 1 2 3\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 9: Replace first operand with a larger vector. + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 3 0 7 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %8 %7 1 0 5 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 10: Replace unused operand with null. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[null:%\\w+]] = OpConstantNull [[v4double]]\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} [[null]] %7 4 2 5 3\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %8 %7 4 2 5 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 11: Replace unused operand with null. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[null:%\\w+]] = OpConstantNull [[v4double]]\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} [[null]] %5 2 2 5 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %8 %8 2 2 3 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 12: Replace unused operand with null. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[null:%\\w+]] = OpConstantNull [[v4double]]\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 [[null]] 2 0 1 3\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 0 1 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 13: Shuffle with undef literal. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 {{%\\w+}} 2 0 1 4294967295\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 1\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 0 1 4294967295\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true) +)); #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/freeze_spec_const_test.cpp b/3rdparty/spirv-tools/test/opt/freeze_spec_const_test.cpp index e91730077..5cc7843b1 100644 --- a/3rdparty/spirv-tools/test/opt/freeze_spec_const_test.cpp +++ b/3rdparty/spirv-tools/test/opt/freeze_spec_const_test.cpp @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" - #include +#include #include +#include #include -namespace { +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" -using namespace spvtools; +namespace spvtools { +namespace opt { +namespace { struct FreezeSpecConstantValueTypeTestCase { const char* type_decl; @@ -40,7 +42,7 @@ TEST_P(FreezeSpecConstantValueTypeTest, PrimaryType) { std::vector expected = { "OpCapability Shader", "OpMemoryModel Logical GLSL450", test_case.type_decl, test_case.expected_frozen_const}; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(text), JoinAllInsts(expected), /* skip_nop = */ false); } @@ -121,8 +123,11 @@ TEST_F(FreezeSpecConstantValueRemoveDecorationTest, << "replace_str:\n" << p.second << "\n"; } - SinglePassRunAndCheck( - JoinAllInsts(text), expected_disassembly, - /* skip_nop = */ true); + SinglePassRunAndCheck(JoinAllInsts(text), + expected_disassembly, + /* skip_nop = */ true); } -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/function_utils.h b/3rdparty/spirv-tools/test/opt/function_utils.h index a392e319b..803cacdd5 100644 --- a/3rdparty/spirv-tools/test/opt/function_utils.h +++ b/3rdparty/spirv-tools/test/opt/function_utils.h @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_OPT_FUNCTION_UTILS_H_ -#define LIBSPIRV_TEST_OPT_FUNCTION_UTILS_H_ +#ifndef TEST_OPT_FUNCTION_UTILS_H_ +#define TEST_OPT_FUNCTION_UTILS_H_ -#include "opt/function.h" -#include "opt/module.h" +#include "source/opt/function.h" +#include "source/opt/module.h" namespace spvtest { -spvtools::ir::Function* GetFunction(spvtools::ir::Module* module, uint32_t id) { - for (spvtools::ir::Function& f : *module) { +inline spvtools::opt::Function* GetFunction(spvtools::opt::Module* module, + uint32_t id) { + for (spvtools::opt::Function& f : *module) { if (f.result_id() == id) { return &f; } @@ -29,9 +30,9 @@ spvtools::ir::Function* GetFunction(spvtools::ir::Module* module, uint32_t id) { return nullptr; } -const spvtools::ir::Function* GetFunction(const spvtools::ir::Module* module, - uint32_t id) { - for (const spvtools::ir::Function& f : *module) { +inline const spvtools::opt::Function* GetFunction( + const spvtools::opt::Module* module, uint32_t id) { + for (const spvtools::opt::Function& f : *module) { if (f.result_id() == id) { return &f; } @@ -39,9 +40,9 @@ const spvtools::ir::Function* GetFunction(const spvtools::ir::Module* module, return nullptr; } -const spvtools::ir::BasicBlock* GetBasicBlock(const spvtools::ir::Function* fn, - uint32_t id) { - for (const spvtools::ir::BasicBlock& bb : *fn) { +inline const spvtools::opt::BasicBlock* GetBasicBlock( + const spvtools::opt::Function* fn, uint32_t id) { + for (const spvtools::opt::BasicBlock& bb : *fn) { if (bb.id() == id) { return &bb; } @@ -51,4 +52,4 @@ const spvtools::ir::BasicBlock* GetBasicBlock(const spvtools::ir::Function* fn, } // namespace spvtest -#endif // LIBSPIRV_TEST_OPT_FUNCTION_UTILS_H_ +#endif // TEST_OPT_FUNCTION_UTILS_H_ diff --git a/3rdparty/spirv-tools/test/opt/if_conversion_test.cpp b/3rdparty/spirv-tools/test/opt/if_conversion_test.cpp index b31147101..a62a15e88 100644 --- a/3rdparty/spirv-tools/test/opt/if_conversion_test.cpp +++ b/3rdparty/spirv-tools/test/opt/if_conversion_test.cpp @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" +#include + #include "gmock/gmock.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using IfConversionTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -58,7 +60,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(IfConversionTest, TestSimpleHalfIfTrue) { @@ -93,7 +95,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(IfConversionTest, TestSimpleHalfIfExtraBlock) { @@ -130,7 +132,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(IfConversionTest, TestSimpleHalfIfFalse) { @@ -165,7 +167,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(IfConversionTest, TestVectorSplat) { @@ -207,7 +209,98 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, CodeMotionSameValue) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NOT: OpLabel +; CHECK: [[add:%\w+]] = OpIAdd %uint %uint_0 %uint_1 +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] None +; CHECK-NEXT: OpBranchConditional +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpStore [[var]] [[add]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "func" %2 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint + %2 = OpVariable %_ptr_Output_uint Output + %8 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %1 = OpFunction %void None %8 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %true %13 %15 + %13 = OpLabel + %14 = OpIAdd %uint %uint_0 %uint_1 + OpBranch %12 + %15 = OpLabel + %16 = OpIAdd %uint %uint_0 %uint_1 + OpBranch %12 + %12 = OpLabel + %17 = OpPhi %uint %16 %15 %14 %13 + OpStore %2 %17 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, CodeMotionMultipleInstructions) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NOT: OpLabel +; CHECK: [[a1:%\w+]] = OpIAdd %uint %uint_0 %uint_1 +; CHECK: [[a2:%\w+]] = OpIAdd %uint [[a1]] %uint_1 +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] None +; CHECK-NEXT: OpBranchConditional +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpStore [[var]] [[a2]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "func" %2 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint + %2 = OpVariable %_ptr_Output_uint Output + %8 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %1 = OpFunction %void None %8 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %true %13 %15 + %13 = OpLabel + %a1 = OpIAdd %uint %uint_0 %uint_1 + %a2 = OpIAdd %uint %a1 %uint_1 + OpBranch %12 + %15 = OpLabel + %b1 = OpIAdd %uint %uint_0 %uint_1 + %b2 = OpIAdd %uint %b1 %uint_1 + OpBranch %12 + %12 = OpLabel + %17 = OpPhi %uint %b2 %15 %a2 %13 + OpStore %2 %17 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE @@ -234,7 +327,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } TEST_F(IfConversionTest, LoopUntouched) { @@ -263,7 +356,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } TEST_F(IfConversionTest, TooManyPredecessors) { @@ -296,7 +389,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } TEST_F(IfConversionTest, NoCodeMotion) { @@ -326,7 +419,56 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(text, text, true, true); + SinglePassRunAndCheck(text, text, true, true); } -} // anonymous namespace +TEST_F(IfConversionTest, NoCodeMotionImmovableInst) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%8 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%1 = OpFunction %void None %8 +%11 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %true %13 %14 +%13 = OpLabel +OpSelectionMerge %15 None +OpBranchConditional %true %16 %15 +%16 = OpLabel +%17 = OpIAdd %uint %uint_0 %uint_1 +OpBranch %15 +%15 = OpLabel +%18 = OpPhi %uint %uint_0 %13 %17 %16 +%19 = OpIAdd %uint %18 %uint_1 +OpBranch %12 +%14 = OpLabel +OpSelectionMerge %20 None +OpBranchConditional %true %21 %20 +%21 = OpLabel +%22 = OpIAdd %uint %uint_0 %uint_1 +OpBranch %20 +%20 = OpLabel +%23 = OpPhi %uint %uint_0 %14 %22 %21 +%24 = OpIAdd %uint %23 %uint_1 +OpBranch %12 +%12 = OpLabel +%25 = OpPhi %uint %24 %20 %19 %15 +OpStore %2 %25 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/inline_opaque_test.cpp b/3rdparty/spirv-tools/test/opt/inline_opaque_test.cpp index d3588f7ea..d10913aec 100644 --- a/3rdparty/spirv-tools/test/opt/inline_opaque_test.cpp +++ b/3rdparty/spirv-tools/test/opt/inline_opaque_test.cpp @@ -13,13 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using InlineOpaqueTest = PassTest<::testing::Test>; TEST_F(InlineOpaqueTest, InlineCallWithStructArgContainingSampledImage) { @@ -72,7 +74,7 @@ OpDecorate %sampler15 DescriptorSet 0 const std::string before = R"(%main = OpFunction %void None %12 %28 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function +%s0 = OpVariable %_ptr_Function_S_t Function %param = OpVariable %_ptr_Function_S_t Function %29 = OpLoad %v2float %texCoords %30 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 @@ -80,7 +82,7 @@ OpStore %30 %29 %31 = OpLoad %18 %sampler15 %32 = OpAccessChain %_ptr_Function_18 %s0 %int_2 OpStore %32 %31 -%33 = OpLoad %S_t %s0 +%33 = OpLoad %S_t %s0 OpStore %param %33 %34 = OpFunctionCall %void %foo_struct_S_t_vf2_vf21_ %param OpReturn @@ -124,7 +126,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before + post_defs, predefs + after + post_defs, true, true); } @@ -172,7 +174,7 @@ OpDecorate %sampler16 DescriptorSet 0 const std::string before = R"(%main = OpFunction %void None %9 %24 = OpLabel -%25 = OpVariable %_ptr_Function_20 Function +%25 = OpVariable %_ptr_Function_20 Function %26 = OpFunctionCall %20 %foo_ OpStore %25 %26 %27 = OpLoad %20 %25 @@ -214,7 +216,7 @@ OpReturnValue %33 OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before + post_defs, predefs + after + post_defs, true, true); } @@ -271,7 +273,7 @@ OpDecorate %sampler15 DescriptorSet 0 const std::string before = R"(%main2 = OpFunction %void None %13 %29 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function +%s0 = OpVariable %_ptr_Function_S_t Function %param = OpVariable %_ptr_Function_S_t Function %30 = OpLoad %v2float %texCoords %31 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 @@ -279,7 +281,7 @@ OpStore %31 %30 %32 = OpLoad %19 %sampler15 %33 = OpAccessChain %_ptr_Function_19 %s0 %int_2 OpStore %33 %32 -%34 = OpLoad %S_t %s0 +%34 = OpLoad %S_t %s0 OpStore %param %34 %35 = OpFunctionCall %void %foo_struct_S_t_vf2_vf21_ %param OpReturn @@ -328,7 +330,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before + post_defs, predefs + after + post_defs, true, true); } @@ -402,7 +404,9 @@ OpReturnValue %31 OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, true); + SinglePassRunAndCheck(assembly, assembly, true, true); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/inline_test.cpp b/3rdparty/spirv-tools/test/opt/inline_test.cpp index 8f3675ad1..4eab77da4 100644 --- a/3rdparty/spirv-tools/test/opt/inline_test.cpp +++ b/3rdparty/spirv-tools/test/opt/inline_test.cpp @@ -13,13 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using InlineTest = PassTest<::testing::Test>; TEST_F(InlineTest, Simple) { @@ -126,7 +130,7 @@ TEST_F(InlineTest, Simple) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -176,14 +180,14 @@ TEST_F(InlineTest, Nested) { "%15 = OpTypeFunction %void", "%float = OpTypeFloat 32", "%_ptr_Function_float = OpTypePointer Function %float", - "%18 = OpTypeFunction %float %_ptr_Function_float %_ptr_Function_float", - "%v4float = OpTypeVector %float 4", + "%18 = OpTypeFunction %float %_ptr_Function_float %_ptr_Function_float", + "%v4float = OpTypeVector %float 4", "%_ptr_Function_v4float = OpTypePointer Function %v4float", "%21 = OpTypeFunction %float %_ptr_Function_v4float", - "%uint = OpTypeInt 32 0", + "%uint = OpTypeInt 32 0", "%uint_0 = OpConstant %uint 0", "%uint_1 = OpConstant %uint 1", - "%uint_2 = OpConstant %uint 2", + "%uint_2 = OpConstant %uint 2", "%_ptr_Input_v4float = OpTypePointer Input %v4float", "%BaseColor = OpVariable %_ptr_Input_v4float Input", "%_ptr_Output_v4float = OpTypePointer Output %v4float", @@ -250,7 +254,7 @@ TEST_F(InlineTest, Nested) { "%48 = OpVariable %_ptr_Function_float Function", "%color = OpVariable %_ptr_Function_v4float Function", "%param_1 = OpVariable %_ptr_Function_v4float Function", - "%29 = OpLoad %v4float %BaseColor", + "%29 = OpLoad %v4float %BaseColor", "OpStore %param_1 %29", "%49 = OpAccessChain %_ptr_Function_float %param_1 %uint_0", "%50 = OpLoad %float %49", @@ -266,7 +270,7 @@ TEST_F(InlineTest, Nested) { "%60 = OpFMul %float %58 %59", "OpStore %57 %60", "%56 = OpLoad %float %57", - "OpStore %48 %56", + "OpStore %48 %56", "%30 = OpLoad %float %48", "%31 = OpCompositeConstruct %v4float %30 %30 %30 %30", "OpStore %color %31", @@ -276,7 +280,7 @@ TEST_F(InlineTest, Nested) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -405,7 +409,7 @@ TEST_F(InlineTest, InOutParameter) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -541,7 +545,7 @@ TEST_F(InlineTest, BranchInCallee) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -615,7 +619,7 @@ TEST_F(InlineTest, PhiAfterCall) { "OpStore %r %44", "%45 = OpLoad %float %r", "%46 = OpFOrdLessThan %bool %45 %float_0", - "OpSelectionMerge %47 None", + "OpSelectionMerge %47 None", "OpBranchConditional %46 %48 %47", "%48 = OpLabel", "%49 = OpLoad %float %r", @@ -643,7 +647,7 @@ TEST_F(InlineTest, PhiAfterCall) { "OpStore %param %30", "%31 = OpFunctionCall %float %foo_f1_ %param", "%32 = OpFOrdGreaterThan %bool %31 %float_2", - "OpSelectionMerge %33 None", + "OpSelectionMerge %33 None", "OpBranchConditional %32 %34 %33", "%34 = OpLabel", "%35 = OpAccessChain %_ptr_Function_float %color %uint_1", @@ -654,7 +658,7 @@ TEST_F(InlineTest, PhiAfterCall) { "OpBranch %33", "%33 = OpLabel", "%39 = OpPhi %bool %32 %27 %38 %34", - "OpSelectionMerge %40 None", + "OpSelectionMerge %40 None", "OpBranchConditional %39 %41 %40", "%41 = OpLabel", "OpStore %color %25", @@ -694,7 +698,7 @@ TEST_F(InlineTest, PhiAfterCall) { "%60 = OpFNegate %float %59", "OpStore %52 %60", "OpBranch %57", - "%57 = OpLabel", + "%57 = OpLabel", "%61 = OpLoad %float %52", "OpStore %53 %61", "%31 = OpLoad %float %53", @@ -710,7 +714,7 @@ TEST_F(InlineTest, PhiAfterCall) { "%65 = OpLoad %float %62", "%66 = OpFOrdLessThan %bool %65 %float_0", "OpSelectionMerge %67 None", - "OpBranchConditional %66 %68 %67", + "OpBranchConditional %66 %68 %67", "%68 = OpLabel", "%69 = OpLoad %float %62", "%70 = OpFNegate %float %69", @@ -736,7 +740,7 @@ TEST_F(InlineTest, PhiAfterCall) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -933,7 +937,7 @@ TEST_F(InlineTest, OpSampledImageOutOfBlock) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -1139,7 +1143,7 @@ TEST_F(InlineTest, OpImageOutOfBlock) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -1207,7 +1211,7 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) { "%uint = OpTypeInt 32 0", "%uint_0 = OpConstant %uint 0", "%float_0 = OpConstant %float 0", - "%bool = OpTypeBool", + "%bool = OpTypeBool", "%26 = OpTypeImage %float 2D 0 0 0 1 Unknown", "%_ptr_UniformConstant_26 = OpTypePointer UniformConstant %26", "%t2D = OpVariable %_ptr_UniformConstant_26 UniformConstant", @@ -1221,7 +1225,7 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) { "%_ptr_Input_v4float = OpTypePointer Input %v4float", "%BaseColor = OpVariable %_ptr_Input_v4float Input", "%samp2 = OpVariable %_ptr_UniformConstant_28 UniformConstant", - "%float_0_5 = OpConstant %float 0.5", + "%float_0_5 = OpConstant %float 0.5", "%36 = OpConstantComposite %v2float %float_0_5 %float_0_5", "%_ptr_Output_v4float = OpTypePointer Output %v4float", "%FragColor = OpVariable %_ptr_Output_v4float Output", @@ -1345,7 +1349,7 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), /* skip_nop = */ false, /* do_validate = */ true); @@ -1473,9 +1477,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before + nonEntryFuncs, predefs + after + nonEntryFuncs, false, - true); + SinglePassRunAndCheck(predefs + before + nonEntryFuncs, + predefs + after + nonEntryFuncs, + false, true); } TEST_F(InlineTest, EarlyReturnNotAppearingLastInFunctionInlined) { @@ -1543,9 +1547,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, ForwardReferencesInPhiInlined) { @@ -1632,9 +1636,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, EarlyReturnInLoopIsNotInlined) { @@ -1729,8 +1733,7 @@ OpReturnValue %41 OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, false, - true); + SinglePassRunAndCheck(assembly, assembly, false, true); } TEST_F(InlineTest, ExternalFunctionIsNotInlined) { @@ -1754,8 +1757,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, false, - true); + SinglePassRunAndCheck(assembly, assembly, false, true); } TEST_F(InlineTest, SingleBlockLoopCallsMultiBlockCallee) { @@ -1826,9 +1828,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, MultiBlockLoopHeaderCallsMultiBlockCallee) { @@ -1903,9 +1905,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMerge) { @@ -1992,9 +1994,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, @@ -2073,9 +2075,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F( @@ -2164,9 +2166,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, CalleeWithMultiReturnAndPhiRequiresEntryBlockRemapping) { @@ -2246,9 +2248,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + nonEntryFuncs + before, predefs + nonEntryFuncs + after, false, - true); + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); } TEST_F(InlineTest, Decorated1) { @@ -2370,9 +2372,9 @@ OpFunctionEnd OpReturnValue %9 OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before + nonEntryFuncs, predefs + after + nonEntryFuncs, false, - true); + SinglePassRunAndCheck(predefs + before + nonEntryFuncs, + predefs + after + nonEntryFuncs, + false, true); } TEST_F(InlineTest, Decorated2) { @@ -2494,9 +2496,9 @@ OpFunctionEnd OpReturnValue %31 OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before + nonEntryFuncs, predefs + after + nonEntryFuncs, false, - true); + SinglePassRunAndCheck(predefs + before + nonEntryFuncs, + predefs + after + nonEntryFuncs, + false, true); } TEST_F(InlineTest, DeleteName) { @@ -2546,7 +2548,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } TEST_F(InlineTest, SetParent) { @@ -2575,18 +2577,271 @@ TEST_F(InlineTest, SetParent) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::InlineExhaustivePass pass; + InlineExhaustivePass pass; pass.Run(context.get()); - for (ir::Function& func : *context->module()) { - for (ir::BasicBlock& bb : func) { + for (Function& func : *context->module()) { + for (BasicBlock& bb : func) { EXPECT_TRUE(bb.GetParent() == &func); } } } +#ifdef SPIRV_EFFCEE +TEST_F(InlineTest, OpKill) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%voidfuncty = OpTypeFunction %void +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpFunctionCall %void %func +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%3 = OpLabel +OpKill +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpKillWithTrailingInstructions) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: [[var:%\w+]] = OpVariable +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpStore [[var]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%bool_func_ptr = OpTypePointer Function %bool +%voidfuncty = OpTypeFunction %void +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpVariable %bool_func_ptr Function +%3 = OpFunctionCall %void %func +OpStore %2 %true +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%4 = OpLabel +OpKill +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpKillInIf) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NEXT: [[ld:%\w+]] = OpLoad {{%\w+}} [[var]] +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpLoopMerge [[loop_merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[kill_label:%\w+]] [[label:%\w+]] +; CHECK-NEXT: [[kill_label]] = OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK-NEXT: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranchConditional +; CHECK-NEXT: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[ld]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%bool_func_ptr = OpTypePointer Function %bool +%voidfuncty = OpTypeFunction %void +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpVariable %bool_func_ptr Function +%3 = OpLoad %bool %2 +%4 = OpFunctionCall %void %func +OpStore %2 %3 +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%5 = OpLabel +OpSelectionMerge %6 None +OpBranchConditional %true %7 %8 +%7 = OpLabel +OpKill +%8 = OpLabel +OpReturn +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpKillInLoop) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NEXT: [[ld:%\w+]] = OpLoad {{%\w+}} [[var]] +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK-NEXT: [[loop]] = OpLabel +; CHECK-NEXT: OpLoopMerge [[loop_merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranch [[loop]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[ld]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%voidfuncty = OpTypeFunction %void +%bool_func_ptr = OpTypePointer Function %bool +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpVariable %bool_func_ptr Function +%3 = OpLoad %bool %2 +%4 = OpFunctionCall %void %func +OpStore %2 %3 +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%5 = OpLabel +OpBranch %10 +%10 = OpLabel +OpLoopMerge %6 %7 None +OpBranch %8 +%8 = OpLabel +OpKill +%6 = OpLabel +OpReturn +%7 = OpLabel +OpBranch %10 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpVariableWithInit) { + // Check that there is a store that corresponds to the initializer. This + // test makes sure that is a store to the variable in the loop and before any + // load. + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NOT: OpFunctionEnd +; CHECK: [[var:%\w+]] = OpVariable %_ptr_Function_float Function %float_0 +; CHECK: OpLoopMerge [[outer_merge:%\w+]] +; CHECK-NOT: OpLoad %float [[var]] +; CHECK: OpStore [[var]] %float_0 +; CHECK: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpDecorate %o Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %7 = OpTypeFunction %float +%_ptr_Function_float = OpTypePointer Function %float + %float_0 = OpConstant %float 0 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 +%_ptr_Output_float = OpTypePointer Output %float + %o = OpVariable %_ptr_Output_float Output + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %main = OpFunction %void None %3 + %5 = OpLabel + OpStore %o %float_0 + OpBranch %34 + %34 = OpLabel + %39 = OpPhi %int %int_0 %5 %47 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpSLessThan %bool %39 %int_2 + OpBranchConditional %41 %35 %36 + %35 = OpLabel + %42 = OpFunctionCall %float %foo_ + %43 = OpLoad %float %o + %44 = OpFAdd %float %43 %42 + OpStore %o %44 + OpBranch %37 + %37 = OpLabel + %47 = OpIAdd %int %39 %int_1 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + %foo_ = OpFunction %float None %7 + %9 = OpLabel + %n = OpVariable %_ptr_Function_float Function %float_0 + %13 = OpLoad %float %n + %15 = OpFOrdEqual %bool %13 %float_0 + OpSelectionMerge %17 None + OpBranchConditional %15 %16 %17 + %16 = OpLabel + %19 = OpLoad %float %n + %20 = OpFAdd %float %19 %float_1 + OpStore %n %20 + OpBranch %17 + %17 = OpLabel + %21 = OpLoad %float %n + OpReturnValue %21 + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} +#endif + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Empty modules @@ -2609,4 +2864,6 @@ TEST_F(InlineTest, SetParent) { // behaviour. // SampledImage after function call. It is not cloned or changed. -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/insert_extract_elim_test.cpp b/3rdparty/spirv-tools/test/opt/insert_extract_elim_test.cpp index afb0036f7..c5169750b 100644 --- a/3rdparty/spirv-tools/test/opt/insert_extract_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/insert_extract_elim_test.cpp @@ -13,13 +13,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "source/opt/simplification_pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using InsertExtractElimTest = PassTest<::testing::Test>; TEST_F(InsertExtractElimTest, Simple) { @@ -99,8 +102,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(InsertExtractElimTest, OptimizeAcrossNonConflictingInsert) { @@ -184,8 +187,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(InsertExtractElimTest, OptimizeOpaque) { @@ -235,9 +238,9 @@ OpDecorate %sampler15 DescriptorSet 0 const std::string before = R"(%main = OpFunction %void None %9 %25 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function +%s0 = OpVariable %_ptr_Function_S_t Function %26 = OpLoad %v2float %texCoords -%27 = OpLoad %S_t %s0 +%27 = OpLoad %S_t %s0 %28 = OpCompositeInsert %S_t %26 %27 0 %29 = OpLoad %15 %sampler15 %30 = OpCompositeInsert %S_t %29 %28 2 @@ -266,8 +269,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(InsertExtractElimTest, OptimizeNestedStruct) { @@ -395,8 +398,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization) { @@ -463,8 +466,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, - true); + SinglePassRunAndCheck(assembly, assembly, true, true); } TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization2) { @@ -580,15 +582,15 @@ OpFunctionEnd %24 = OpCompositeInsert %S_t %float_1 %23 1 1 %25 = OpLoad %v4float %BaseColor %26 = OpCompositeInsert %S_t %25 %24 1 -%27 = OpCompositeExtract %float %26 1 1 +%27 = OpCompositeExtract %float %25 1 %28 = OpCompositeConstruct %v4float %27 %float_0 %float_0 %float_0 OpStore %gl_FragColor %28 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - before_predefs + before, after_predefs + after, true, true); + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); } TEST_F(InsertExtractElimTest, MixWithConstants) { @@ -689,8 +691,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); } TEST_F(InsertExtractElimTest, VectorShuffle1) { @@ -713,7 +715,7 @@ TEST_F(InsertExtractElimTest, VectorShuffle1) { // OutColor = vec4(v.y); // } - const std::string predefs = + const std::string predefs_before = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -743,6 +745,10 @@ OpDecorate %OutColor Location 0 %_ptr_Function_float = OpTypePointer Function %float )"; + const std::string predefs_after = predefs_before + + "%24 = OpConstantComposite %v4float " + "%float_1 %float_1 %float_1 %float_1\n"; + const std::string before = R"(%main = OpFunction %void None %7 %17 = OpLabel @@ -764,14 +770,13 @@ OpFunctionEnd %19 = OpLoad %float %bc2 %20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1 %21 = OpVectorShuffle %v4float %20 %20 2 3 0 1 -%23 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 -OpStore %OutColor %23 +OpStore %OutColor %24 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs_before + before, + predefs_after + after, true, true); } TEST_F(InsertExtractElimTest, VectorShuffle2) { @@ -796,7 +801,7 @@ TEST_F(InsertExtractElimTest, VectorShuffle2) { // OutColor = vec4(v.y); // } - const std::string predefs = + const std::string predefs_before = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -824,6 +829,37 @@ OpDecorate %OutColor Location 0 %OutColor = OpVariable %_ptr_Output_v4float Output %uint = OpTypeInt 32 0 %_ptr_Function_float = OpTypePointer Function %float +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %bc %bc2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %bc "bc" +OpName %bc2 "bc2" +OpName %OutColor "OutColor" +OpDecorate %bc Location 0 +OpDecorate %bc2 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_float = OpTypePointer Input %float +%bc = OpVariable %_ptr_Input_float Input +%bc2 = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%24 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 )"; const std::string before = @@ -847,17 +883,18 @@ OpFunctionEnd %19 = OpLoad %float %bc2 %20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1 %21 = OpVectorShuffle %v4float %20 %20 2 7 0 1 -%23 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 -OpStore %OutColor %23 +OpStore %OutColor %24 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs_before + before, + predefs_after + after, true, true); } // TODO(greg-lunarg): Add tests to verify handling of these cases: // -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/instruction_list_test.cpp b/3rdparty/spirv-tools/test/opt/instruction_list_test.cpp index b8dec2f77..e745790a3 100644 --- a/3rdparty/spirv-tools/test/opt/instruction_list_test.cpp +++ b/3rdparty/spirv-tools/test/opt/instruction_list_test.cpp @@ -14,18 +14,18 @@ #include #include +#include #include #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "source/opt/instruction.h" +#include "source/opt/instruction_list.h" -#include "opt/instruction.h" -#include "opt/instruction_list.h" - +namespace spvtools { +namespace opt { namespace { -using Instruction = spvtools::ir::Instruction; -using InstructionList = spvtools::ir::InstructionList; using ::testing::ContainerEq; using ::testing::ElementsAre; using InstructionListTest = ::testing::Test; @@ -109,4 +109,7 @@ TEST(InstructionListTest, InsertBefore2) { } EXPECT_THAT(output, ContainerEq(created_instructions)); } + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/instruction_test.cpp b/3rdparty/spirv-tools/test/opt/instruction_test.cpp index 0a632a98b..2ace6b8ac 100644 --- a/3rdparty/spirv-tools/test/opt/instruction_test.cpp +++ b/3rdparty/spirv-tools/test/opt/instruction_test.cpp @@ -12,27 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/instruction.h" -#include "opt/ir_context.h" +#include +#include +#include #include "gmock/gmock.h" - -#include "pass_fixture.h" -#include "pass_utils.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" #include "spirv-tools/libspirv.h" -#include "unit_spirv.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +#include "test/unit_spirv.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; -using ir::Instruction; -using ir::IRContext; -using ir::Operand; using spvtest::MakeInstruction; using ::testing::Eq; using DescriptorTypeTest = PassTest<::testing::Test>; using OpaqueTypeTest = PassTest<::testing::Test>; using GetBaseTest = PassTest<::testing::Test>; +using ValidBasePointerTest = PassTest<::testing::Test>; TEST(InstructionTest, CreateTrivial) { Instruction empty; @@ -316,7 +317,7 @@ TEST_F(DescriptorTypeTest, StorageImage) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* type = context->get_def_use_mgr()->GetDef(8); EXPECT_TRUE(type->IsVulkanStorageImage()); @@ -352,7 +353,7 @@ TEST_F(DescriptorTypeTest, SampledImage) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* type = context->get_def_use_mgr()->GetDef(8); EXPECT_FALSE(type->IsVulkanStorageImage()); @@ -388,7 +389,7 @@ TEST_F(DescriptorTypeTest, StorageTexelBuffer) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* type = context->get_def_use_mgr()->GetDef(8); EXPECT_FALSE(type->IsVulkanStorageImage()); @@ -427,7 +428,7 @@ TEST_F(DescriptorTypeTest, StorageBuffer) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* type = context->get_def_use_mgr()->GetDef(10); EXPECT_FALSE(type->IsVulkanStorageImage()); @@ -466,7 +467,7 @@ TEST_F(DescriptorTypeTest, UniformBuffer) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* type = context->get_def_use_mgr()->GetDef(10); EXPECT_FALSE(type->IsVulkanStorageImage()); @@ -506,7 +507,7 @@ TEST_F(DescriptorTypeTest, NonWritableIsReadOnly) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* variable = context->get_def_use_mgr()->GetDef(3); EXPECT_TRUE(variable->IsReadOnlyVariable()); @@ -533,7 +534,7 @@ TEST_F(OpaqueTypeTest, BaseOpaqueTypesShader) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* image_type = context->get_def_use_mgr()->GetDef(6); EXPECT_TRUE(image_type->IsOpaqueType()); @@ -571,7 +572,7 @@ TEST_F(OpaqueTypeTest, OpaqueStructTypes) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); for (int i = 7; i <= 10; i++) { Instruction* type = context->get_def_use_mgr()->GetDef(i); @@ -614,7 +615,7 @@ TEST_F(GetBaseTest, SampleImage) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* load = context->get_def_use_mgr()->GetDef(21); Instruction* base = context->get_def_use_mgr()->GetDef(20); @@ -649,10 +650,457 @@ TEST_F(GetBaseTest, ImageRead) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); Instruction* load = context->get_def_use_mgr()->GetDef(14); Instruction* base = context->get_def_use_mgr()->GetDef(13); EXPECT_TRUE(load->GetBaseAddress() == base); } -} // anonymous namespace + +TEST_F(ValidBasePointerTest, OpSelectBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_FALSE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpSelectBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_FALSE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpSelectGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_TRUE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpSelectGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_TRUE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_FALSE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_FALSE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_TRUE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_TRUE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/ir_builder.cpp b/3rdparty/spirv-tools/test/opt/ir_builder.cpp index 6096b49db..7eeb86dd3 100644 --- a/3rdparty/spirv-tools/test/opt/ir_builder.cpp +++ b/3rdparty/spirv-tools/test/opt/ir_builder.cpp @@ -12,31 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include +#include +#include #include +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/basic_block.h" +#include "source/opt/build_module.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/type_manager.h" +#include "spirv-tools/libspirv.hpp" + #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "opt/basic_block.h" -#include "opt/ir_builder.h" - -#include "opt/build_module.h" -#include "opt/instruction.h" -#include "opt/type_manager.h" -#include "spirv-tools/libspirv.hpp" - +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; -using ir::IRContext; -using Analysis = IRContext::Analysis; - #ifdef SPIRV_EFFCEE +using Analysis = IRContext::Analysis; using IRBuilderTest = ::testing::Test; bool Validate(const std::vector& bin) { @@ -51,7 +51,7 @@ bool Validate(const std::vector& bin) { return error == 0; } -void Match(const std::string& original, ir::IRContext* context, +void Match(const std::string& original, IRContext* context, bool do_validation = true) { std::vector bin; context->module()->ToBinary(&bin, true); @@ -112,18 +112,18 @@ TEST_F(IRBuilderTest, TestInsnAddition) { )"; { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - ir::BasicBlock* bb = context->cfg()->block(18); + BasicBlock* bb = context->cfg()->block(18); // Build managers. context->get_def_use_mgr(); context->get_instr_block(nullptr); - opt::InstructionBuilder builder(context.get(), &*bb->begin()); - ir::Instruction* phi1 = builder.AddPhi(7, {9, 14}); - ir::Instruction* phi2 = builder.AddPhi(10, {16, 14}); + InstructionBuilder builder(context.get(), &*bb->begin()); + Instruction* phi1 = builder.AddPhi(7, {9, 14}); + Instruction* phi2 = builder.AddPhi(10, {16, 14}); // Make sure the InstructionBuilder did not update the def/use manager. EXPECT_EQ(context->get_def_use_mgr()->GetDef(phi1->result_id()), nullptr); @@ -135,20 +135,19 @@ TEST_F(IRBuilderTest, TestInsnAddition) { } { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); // Build managers. context->get_def_use_mgr(); context->get_instr_block(nullptr); - ir::BasicBlock* bb = context->cfg()->block(18); - opt::InstructionBuilder builder( + BasicBlock* bb = context->cfg()->block(18); + InstructionBuilder builder( context.get(), &*bb->begin(), - ir::IRContext::kAnalysisDefUse | - ir::IRContext::kAnalysisInstrToBlockMapping); - ir::Instruction* phi1 = builder.AddPhi(7, {9, 14}); - ir::Instruction* phi2 = builder.AddPhi(10, {16, 14}); + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* phi1 = builder.AddPhi(7, {9, 14}); + Instruction* phi2 = builder.AddPhi(10, {16, 14}); // Make sure InstructionBuilder updated the def/use manager EXPECT_NE(context->get_def_use_mgr()->GetDef(phi1->result_id()), nullptr); @@ -197,28 +196,28 @@ TEST_F(IRBuilderTest, TestCondBranchAddition) { )"; { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - ir::Function& fn = *context->module()->begin(); + Function& fn = *context->module()->begin(); - ir::BasicBlock& bb_merge = *fn.begin(); + BasicBlock& bb_merge = *fn.begin(); - fn.begin().InsertBefore(std::unique_ptr( - new ir::BasicBlock(std::unique_ptr(new ir::Instruction( + fn.begin().InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( context.get(), SpvOpLabel, 0, context->TakeNextId(), {}))))); - ir::BasicBlock& bb_true = *fn.begin(); + BasicBlock& bb_true = *fn.begin(); { - opt::InstructionBuilder builder(context.get(), &*bb_true.begin()); + InstructionBuilder builder(context.get(), &*bb_true.begin()); builder.AddBranch(bb_merge.id()); } - fn.begin().InsertBefore(std::unique_ptr( - new ir::BasicBlock(std::unique_ptr(new ir::Instruction( + fn.begin().InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( context.get(), SpvOpLabel, 0, context->TakeNextId(), {}))))); - ir::BasicBlock& bb_cond = *fn.begin(); + BasicBlock& bb_cond = *fn.begin(); - opt::InstructionBuilder builder(context.get(), &bb_cond); + InstructionBuilder builder(context.get(), &bb_cond); // This also test consecutive instruction insertion: merge selection + // branch. builder.AddConditionalBranch(9, bb_true.id(), bb_merge.id(), bb_merge.id()); @@ -251,12 +250,12 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(nullptr, context); - opt::InstructionBuilder builder( - context.get(), &*context->module()->begin()->begin()->begin()); + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); EXPECT_NE(nullptr, builder.AddSelect(3u, 4u, 5u, 6u)); Match(text, context.get()); @@ -284,12 +283,12 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(nullptr, context); - opt::InstructionBuilder builder( - context.get(), &*context->module()->begin()->begin()->begin()); + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); std::vector ids = {3u, 4u, 4u, 3u}; EXPECT_NE(nullptr, builder.AddCompositeConstruct(5u, ids)); @@ -317,12 +316,12 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(nullptr, context); - opt::InstructionBuilder builder( - context.get(), &*context->module()->begin()->begin()->begin()); + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); EXPECT_NE(nullptr, builder.Add32BitUnsignedIntegerConstant(13)); EXPECT_NE(nullptr, builder.Add32BitSignedIntegerConstant(-1)); @@ -362,14 +361,14 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(nullptr, context); - opt::InstructionBuilder builder( - context.get(), &*context->module()->begin()->begin()->begin()); - ir::Instruction* const_1 = builder.Add32BitUnsignedIntegerConstant(13); - ir::Instruction* const_2 = builder.Add32BitSignedIntegerConstant(-1); + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); + Instruction* const_1 = builder.Add32BitUnsignedIntegerConstant(13); + Instruction* const_2 = builder.Add32BitSignedIntegerConstant(-1); EXPECT_NE(nullptr, const_1); EXPECT_NE(nullptr, const_2); @@ -378,15 +377,15 @@ OpFunctionEnd EXPECT_EQ(const_1, builder.Add32BitUnsignedIntegerConstant(13)); EXPECT_EQ(const_2, builder.Add32BitSignedIntegerConstant(-1)); - ir::Instruction* const_3 = builder.Add32BitUnsignedIntegerConstant(1); - ir::Instruction* const_4 = builder.Add32BitSignedIntegerConstant(34); + Instruction* const_3 = builder.Add32BitUnsignedIntegerConstant(1); + Instruction* const_4 = builder.Add32BitSignedIntegerConstant(34); // Try adding different constants to make sure the type is reused. EXPECT_NE(nullptr, const_3); EXPECT_NE(nullptr, const_4); - ir::Instruction* const_5 = builder.Add32BitUnsignedIntegerConstant(0); - ir::Instruction* const_6 = builder.Add32BitSignedIntegerConstant(0); + Instruction* const_5 = builder.Add32BitUnsignedIntegerConstant(0); + Instruction* const_6 = builder.Add32BitSignedIntegerConstant(0); // Try adding 0 as both signed and unsigned. EXPECT_NE(nullptr, const_5); @@ -412,4 +411,6 @@ OpFunctionEnd #endif // SPIRV_EFFCEE -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/ir_context_test.cpp b/3rdparty/spirv-tools/test/opt/ir_context_test.cpp index ad851ed72..c64e5b04f 100644 --- a/3rdparty/spirv-tools/test/opt/ir_context_test.cpp +++ b/3rdparty/spirv-tools/test/opt/ir_context_test.cpp @@ -12,56 +12,68 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include +#include +#include +#include -#include "opt/ir_context.h" -#include "opt/pass.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; -using ir::IRContext; using Analysis = IRContext::Analysis; using ::testing::Each; -class DummyPassPreservesNothing : public opt::Pass { +class DummyPassPreservesNothing : public Pass { public: - DummyPassPreservesNothing(Status s) : opt::Pass(), status_to_return_(s) {} + DummyPassPreservesNothing(Status s) : Pass(), status_to_return_(s) {} + const char* name() const override { return "dummy-pass"; } - Status Process(IRContext*) override { return status_to_return_; } + Status Process() override { return status_to_return_; } + + private: Status status_to_return_; }; -class DummyPassPreservesAll : public opt::Pass { +class DummyPassPreservesAll : public Pass { public: - DummyPassPreservesAll(Status s) : opt::Pass(), status_to_return_(s) {} + DummyPassPreservesAll(Status s) : Pass(), status_to_return_(s) {} + const char* name() const override { return "dummy-pass"; } - Status Process(IRContext*) override { return status_to_return_; } - Status status_to_return_; - virtual Analysis GetPreservedAnalyses() override { + Status Process() override { return status_to_return_; } + + Analysis GetPreservedAnalyses() override { return Analysis(IRContext::kAnalysisEnd - 1); } + + private: + Status status_to_return_; }; -class DummyPassPreservesFirst : public opt::Pass { +class DummyPassPreservesFirst : public Pass { public: - DummyPassPreservesFirst(Status s) : opt::Pass(), status_to_return_(s) {} + DummyPassPreservesFirst(Status s) : Pass(), status_to_return_(s) {} + const char* name() const override { return "dummy-pass"; } - Status Process(IRContext*) override { return status_to_return_; } + Status Process() override { return status_to_return_; } + + Analysis GetPreservedAnalyses() override { return IRContext::kAnalysisBegin; } + + private: Status status_to_return_; - virtual Analysis GetPreservedAnalyses() override { - return IRContext::kAnalysisBegin; - } }; using IRContextTest = PassTest<::testing::Test>; TEST_F(IRContextTest, IndividualValidAfterBuild) { - std::unique_ptr module(new ir::Module()); + std::unique_ptr module(new Module()); IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), spvtools::MessageConsumer()); @@ -73,7 +85,7 @@ TEST_F(IRContextTest, IndividualValidAfterBuild) { } TEST_F(IRContextTest, AllValidAfterBuild) { - std::unique_ptr module = MakeUnique(); + std::unique_ptr module = MakeUnique(); IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), spvtools::MessageConsumer()); @@ -87,7 +99,7 @@ TEST_F(IRContextTest, AllValidAfterBuild) { } TEST_F(IRContextTest, AllValidAfterPassNoChange) { - std::unique_ptr module = MakeUnique(); + std::unique_ptr module = MakeUnique(); IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), spvtools::MessageConsumer()); @@ -98,14 +110,14 @@ TEST_F(IRContextTest, AllValidAfterPassNoChange) { built_analyses |= i; } - DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithoutChange); - opt::Pass::Status s = pass.Run(&localContext); - EXPECT_EQ(s, opt::Pass::Status::SuccessWithoutChange); + DummyPassPreservesNothing pass(Pass::Status::SuccessWithoutChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithoutChange); EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); } TEST_F(IRContextTest, NoneValidAfterPassWithChange) { - std::unique_ptr module = MakeUnique(); + std::unique_ptr module = MakeUnique(); IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), spvtools::MessageConsumer()); @@ -114,9 +126,9 @@ TEST_F(IRContextTest, NoneValidAfterPassWithChange) { localContext.BuildInvalidAnalyses(i); } - DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&localContext); - EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); + DummyPassPreservesNothing pass(Pass::Status::SuccessWithChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithChange); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { EXPECT_FALSE(localContext.AreAnalysesValid(i)); @@ -124,7 +136,7 @@ TEST_F(IRContextTest, NoneValidAfterPassWithChange) { } TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { - std::unique_ptr module = MakeUnique(); + std::unique_ptr module = MakeUnique(); IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), spvtools::MessageConsumer()); @@ -133,9 +145,9 @@ TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { localContext.BuildInvalidAnalyses(i); } - DummyPassPreservesAll pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&localContext); - EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); + DummyPassPreservesAll pass(Pass::Status::SuccessWithChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithChange); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { EXPECT_TRUE(localContext.AreAnalysesValid(i)); @@ -143,7 +155,7 @@ TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { } TEST_F(IRContextTest, PreserveFirstOnlyAfterPassWithChange) { - std::unique_ptr module = MakeUnique(); + std::unique_ptr module = MakeUnique(); IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), spvtools::MessageConsumer()); @@ -152,9 +164,9 @@ TEST_F(IRContextTest, PreserveFirstOnlyAfterPassWithChange) { localContext.BuildInvalidAnalyses(i); } - DummyPassPreservesFirst pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&localContext); - EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); + DummyPassPreservesFirst pass(Pass::Status::SuccessWithChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithChange); EXPECT_TRUE(localContext.AreAnalysesValid(IRContext::kAnalysisBegin)); for (Analysis i = IRContext::kAnalysisBegin << 1; i < IRContext::kAnalysisEnd; i <<= 1) { @@ -184,7 +196,7 @@ TEST_F(IRContextTest, KillMemberName) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); // Build the decoration manager. @@ -212,4 +224,6 @@ TEST_F(IRContextTest, TakeNextUniqueIdIncrementing) { EXPECT_EQ(i, localContext.TakeNextUniqueId()); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/ir_loader_test.cpp b/3rdparty/spirv-tools/test/opt/ir_loader_test.cpp index 87e7e9777..ac5c52075 100644 --- a/3rdparty/spirv-tools/test/opt/ir_loader_test.cpp +++ b/3rdparty/spirv-tools/test/opt/ir_loader_test.cpp @@ -12,22 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include +#include #include +#include +#include -#include "message.h" -#include "opt/build_module.h" -#include "opt/ir_context.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" #include "spirv-tools/libspirv.hpp" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - void DoRoundTripCheck(const std::string& text) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); ASSERT_NE(nullptr, context) << "Failed to assemble\n" << text; @@ -214,14 +217,14 @@ TEST(IrBuilder, OpUndefOutsideFunction) { // clang-format on SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); ASSERT_NE(nullptr, context); const auto opundef_count = std::count_if( context->module()->types_values_begin(), context->module()->types_values_end(), - [](const ir::Instruction& inst) { return inst.opcode() == SpvOpUndef; }); + [](const Instruction& inst) { return inst.opcode() == SpvOpUndef; }); EXPECT_EQ(3, opundef_count); std::vector binary; @@ -317,28 +320,28 @@ TEST(IrBuilder, KeepModuleProcessedInRightPlace) { // Checks the given |error_message| is reported when trying to build a module // from the given |assembly|. void DoErrorMessageCheck(const std::string& assembly, - const std::string& error_message) { - auto consumer = [error_message](spv_message_level_t level, const char* source, - const spv_position_t& position, - const char* m) { - EXPECT_EQ(error_message, StringifyMessage(level, source, position, m)); + const std::string& error_message, uint32_t line_num) { + auto consumer = [error_message, line_num](spv_message_level_t, const char*, + const spv_position_t& position, + const char* m) { + EXPECT_EQ(error_message, m); + EXPECT_EQ(line_num, position.line); }; SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, std::move(consumer), assembly); EXPECT_EQ(nullptr, context); } TEST(IrBuilder, FunctionInsideFunction) { DoErrorMessageCheck("%2 = OpFunction %1 None %3\n%5 = OpFunction %4 None %6", - "error: :2:0:0: function inside function"); + "function inside function", 2); } TEST(IrBuilder, MismatchOpFunctionEnd) { DoErrorMessageCheck("OpFunctionEnd", - "error: :1:0:0: OpFunctionEnd without " - "corresponding OpFunction"); + "OpFunctionEnd without corresponding OpFunction", 1); } TEST(IrBuilder, OpFunctionEndInsideBasicBlock) { @@ -346,12 +349,12 @@ TEST(IrBuilder, OpFunctionEndInsideBasicBlock) { "%2 = OpFunction %1 None %3\n" "%4 = OpLabel\n" "OpFunctionEnd", - "error: :3:0:0: OpFunctionEnd inside basic block"); + "OpFunctionEnd inside basic block", 3); } TEST(IrBuilder, BasicBlockOutsideFunction) { DoErrorMessageCheck("OpCapability Shader\n%1 = OpLabel", - "error: :2:0:0: OpLabel outside function"); + "OpLabel outside function", 2); } TEST(IrBuilder, OpLabelInsideBasicBlock) { @@ -359,26 +362,23 @@ TEST(IrBuilder, OpLabelInsideBasicBlock) { "%2 = OpFunction %1 None %3\n" "%4 = OpLabel\n" "%5 = OpLabel", - "error: :3:0:0: OpLabel inside basic block"); + "OpLabel inside basic block", 3); } TEST(IrBuilder, TerminatorOutsideFunction) { - DoErrorMessageCheck( - "OpReturn", - "error: :1:0:0: terminator instruction outside function"); + DoErrorMessageCheck("OpReturn", "terminator instruction outside function", 1); } TEST(IrBuilder, TerminatorOutsideBasicBlock) { DoErrorMessageCheck("%2 = OpFunction %1 None %3\nOpReturn", - "error: :2:0:0: terminator instruction " - "outside basic block"); + "terminator instruction outside basic block", 2); } TEST(IrBuilder, NotAllowedInstAppearingInFunction) { DoErrorMessageCheck("%2 = OpFunction %1 None %3\n%5 = OpVariable %4 Function", - "error: :2:0:0: Non-OpFunctionParameter " - "(opcode: 59) found inside function but outside basic " - "block"); + "Non-OpFunctionParameter (opcode: 59) found inside " + "function but outside basic block", + 2); } TEST(IrBuilder, UniqueIds) { @@ -436,14 +436,16 @@ TEST(IrBuilder, UniqueIds) { "OpFunctionEnd\n"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); ASSERT_NE(nullptr, context); std::unordered_set ids; - context->module()->ForEachInst([&ids](const ir::Instruction* inst) { + context->module()->ForEachInst([&ids](const Instruction* inst) { EXPECT_TRUE(ids.insert(inst->unique_id()).second); }); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/iterator_test.cpp b/3rdparty/spirv-tools/test/opt/iterator_test.cpp index 5afe88c15..d61bc1ab8 100644 --- a/3rdparty/spirv-tools/test/opt/iterator_test.cpp +++ b/3rdparty/spirv-tools/test/opt/iterator_test.cpp @@ -17,12 +17,13 @@ #include "gmock/gmock.h" -#include "opt/iterator.h" -#include "opt/make_unique.h" +#include "source/opt/iterator.h" +#include "source/util/make_unique.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::ContainerEq; TEST(Iterator, IncrementDeref) { @@ -32,8 +33,8 @@ TEST(Iterator, IncrementDeref) { data.emplace_back(new int(i)); } - ir::UptrVectorIterator it(&data, data.begin()); - ir::UptrVectorIterator end(&data, data.end()); + UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); EXPECT_EQ(*data[0], *it); for (int i = 1; i < count; ++i) { @@ -50,8 +51,8 @@ TEST(Iterator, DecrementDeref) { data.emplace_back(new int(i)); } - ir::UptrVectorIterator begin(&data, data.begin()); - ir::UptrVectorIterator it(&data, data.end()); + UptrVectorIterator begin(&data, data.begin()); + UptrVectorIterator it(&data, data.end()); for (int i = count - 1; i >= 0; --i) { EXPECT_NE(begin, it); @@ -67,8 +68,8 @@ TEST(Iterator, PostIncrementDeref) { data.emplace_back(new int(i)); } - ir::UptrVectorIterator it(&data, data.begin()); - ir::UptrVectorIterator end(&data, data.end()); + UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); for (int i = 0; i < count; ++i) { EXPECT_NE(end, it); @@ -84,9 +85,9 @@ TEST(Iterator, PostDecrementDeref) { data.emplace_back(new int(i)); } - ir::UptrVectorIterator begin(&data, data.begin()); - ir::UptrVectorIterator end(&data, data.end()); - ir::UptrVectorIterator it(&data, data.end()); + UptrVectorIterator begin(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); + UptrVectorIterator it(&data, data.end()); EXPECT_EQ(end, it--); for (int i = count - 1; i >= 1; --i) { @@ -103,7 +104,7 @@ TEST(Iterator, Access) { data.emplace_back(new int(i)); } - ir::UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator it(&data, data.begin()); for (int i = 0; i < count; ++i) EXPECT_EQ(*data[i], it[i]); } @@ -115,8 +116,8 @@ TEST(Iterator, Comparison) { data.emplace_back(new int(i)); } - ir::UptrVectorIterator it(&data, data.begin()); - ir::UptrVectorIterator end(&data, data.end()); + UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); for (int i = 0; i < count; ++i, ++it) EXPECT_TRUE(it < end); EXPECT_EQ(end, it); @@ -136,7 +137,7 @@ TEST(Iterator, InsertBeginEnd) { // Insert at the beginning expected.insert(expected.begin(), -100); - ir::UptrVectorIterator begin(&data, data.begin()); + UptrVectorIterator begin(&data, data.begin()); auto insert_point = begin.InsertBefore(MakeUnique(-100)); for (int i = 0; i < count + 1; ++i) { actual.push_back(*(insert_point++)); @@ -147,13 +148,13 @@ TEST(Iterator, InsertBeginEnd) { expected.push_back(-42); expected.push_back(-36); expected.push_back(-77); - ir::UptrVectorIterator end(&data, data.end()); + UptrVectorIterator end(&data, data.end()); end = end.InsertBefore(MakeUnique(-77)); end = end.InsertBefore(MakeUnique(-36)); end = end.InsertBefore(MakeUnique(-42)); actual.clear(); - begin = ir::UptrVectorIterator(&data, data.begin()); + begin = UptrVectorIterator(&data, data.begin()); for (int i = 0; i < count + 4; ++i) { actual.push_back(*(begin++)); } @@ -176,11 +177,11 @@ TEST(Iterator, InsertMiddle) { expected.insert(expected.begin() + insert_pos, -100); expected.insert(expected.begin() + insert_pos, -42); - ir::UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator it(&data, data.begin()); for (int i = 0; i < insert_pos; ++i) ++it; it = it.InsertBefore(MakeUnique(-100)); it = it.InsertBefore(MakeUnique(-42)); - auto begin = ir::UptrVectorIterator(&data, data.begin()); + auto begin = UptrVectorIterator(&data, data.begin()); for (int i = 0; i < count + 2; ++i) { actual.push_back(*(begin++)); } @@ -196,9 +197,9 @@ TEST(IteratorRange, Interface) { data.emplace_back(new uint32_t(i)); } - auto b = ir::UptrVectorIterator(&data, data.begin()); - auto e = ir::UptrVectorIterator(&data, data.end()); - auto range = ir::IteratorRange(b, e); + auto b = UptrVectorIterator(&data, data.begin()); + auto e = UptrVectorIterator(&data, data.end()); + auto range = IteratorRange(b, e); EXPECT_EQ(b, range.begin()); EXPECT_EQ(e, range.end()); @@ -214,4 +215,53 @@ TEST(IteratorRange, Interface) { EXPECT_EQ(count, range.size()); } -} // anonymous namespace +TEST(Iterator, FilterIterator) { + struct Placeholder { + int val; + }; + std::vector data = {{1}, {2}, {3}, {4}, {5}, + {6}, {7}, {8}, {9}, {10}}; + + // Predicate to only consider odd values. + struct Predicate { + bool operator()(const Placeholder& data) { return data.val % 2; } + }; + Predicate pred; + + auto filter_range = MakeFilterIteratorRange(data.begin(), data.end(), pred); + + EXPECT_EQ(filter_range.begin().Get(), data.begin()); + EXPECT_EQ(filter_range.end(), filter_range.begin().GetEnd()); + + for (Placeholder& data : filter_range) { + EXPECT_EQ(data.val % 2, 1); + } + + for (auto it = filter_range.begin(); it != filter_range.end(); it++) { + EXPECT_EQ(it->val % 2, 1); + EXPECT_EQ((*it).val % 2, 1); + } + + for (auto it = filter_range.begin(); it != filter_range.end(); ++it) { + EXPECT_EQ(it->val % 2, 1); + EXPECT_EQ((*it).val % 2, 1); + } + + EXPECT_EQ(MakeFilterIterator(data.begin(), data.end(), pred).Get(), + data.begin()); + EXPECT_EQ(MakeFilterIterator(data.end(), data.end(), pred).Get(), data.end()); + EXPECT_EQ(MakeFilterIterator(data.begin(), data.end(), pred).GetEnd(), + MakeFilterIterator(data.end(), data.end(), pred)); + EXPECT_NE(MakeFilterIterator(data.begin(), data.end(), pred), + MakeFilterIterator(data.end(), data.end(), pred)); + + // Empty range: no values satisfies the predicate. + auto empty_range = MakeFilterIteratorRange( + data.begin(), data.end(), + [](const Placeholder& data) { return data.val > 10; }); + EXPECT_EQ(empty_range.begin(), empty_range.end()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/line_debug_info_test.cpp b/3rdparty/spirv-tools/test/opt/line_debug_info_test.cpp index 2bb794864..6a20a0136 100644 --- a/3rdparty/spirv-tools/test/opt/line_debug_info_test.cpp +++ b/3rdparty/spirv-tools/test/opt/line_debug_info_test.cpp @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - // A pass turning all none debug line instructions into Nop. -class NopifyPass : public opt::Pass { +class NopifyPass : public Pass { public: const char* name() const override { return "NopifyPass"; } - Status Process(ir::IRContext* irContext) override { + Status Process() override { bool modified = false; - irContext->module()->ForEachInst( - [&modified](ir::Instruction* inst) { + context()->module()->ForEachInst( + [&modified](Instruction* inst) { inst->ToNop(); modified = true; }, @@ -108,4 +108,6 @@ TEST_F(PassTestForLineDebugInfo, KeepLineDebugInfo) { /* skip_nop = */ true); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/local_access_chain_convert_test.cpp b/3rdparty/spirv-tools/test/opt/local_access_chain_convert_test.cpp index b559a814e..cb3572161 100644 --- a/3rdparty/spirv-tools/test/opt/local_access_chain_convert_test.cpp +++ b/3rdparty/spirv-tools/test/opt/local_access_chain_convert_test.cpp @@ -13,15 +13,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using LocalAccessChainConvertTest = PassTest<::testing::Test>; +#ifdef SPIRV_EFFCEE + TEST_F(LocalAccessChainConvertTest, StructOfVecsOfFloatConverted) { // #version 140 // @@ -66,38 +70,18 @@ OpName %gl_FragColor "gl_FragColor" %_ptr_Function_v4float = OpTypePointer Function %v4float %_ptr_Output_v4float = OpTypePointer Output %v4float %gl_FragColor = OpVariable %_ptr_Output_v4float Output -)"; - - const std::string predefs_after = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor -OpExecutionMode %main OriginUpperLeft -OpSource GLSL 140 -OpName %main "main" -OpName %S_t "S_t" -OpMemberName %S_t 0 "v0" -OpMemberName %S_t 1 "v1" -OpName %s0 "s0" -OpName %BaseColor "BaseColor" -OpName %gl_FragColor "gl_FragColor" -%void = OpTypeVoid -%8 = OpTypeFunction %void -%float = OpTypeFloat 32 -%v4float = OpTypeVector %float 4 -%S_t = OpTypeStruct %v4float %v4float -%_ptr_Function_S_t = OpTypePointer Function %S_t -%int = OpTypeInt 32 1 -%_ptr_Input_v4float = OpTypePointer Input %v4float -%BaseColor = OpVariable %_ptr_Input_v4float Input -%_ptr_Function_v4float = OpTypePointer Function %v4float -%_ptr_Output_v4float = OpTypePointer Output %v4float -%gl_FragColor = OpVariable %_ptr_Output_v4float Output )"; const std::string before = - R"(%main = OpFunction %void None %8 + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex1:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 %17 = OpLabel %s0 = OpVariable %_ptr_Function_S_t Function %18 = OpLoad %v4float %BaseColor @@ -110,23 +94,8 @@ OpReturn OpFunctionEnd )"; - const std::string after = - R"(%main = OpFunction %void None %8 -%17 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function -%18 = OpLoad %v4float %BaseColor -%22 = OpLoad %S_t %s0 -%23 = OpCompositeInsert %S_t %18 %22 1 -OpStore %s0 %23 -%24 = OpLoad %S_t %s0 -%25 = OpCompositeExtract %v4float %24 1 -OpStore %gl_FragColor %25 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck( - predefs_before + before, predefs_after + after, true, true); + SinglePassRunAndMatch(predefs_before + before, + true); } TEST_F(LocalAccessChainConvertTest, InBoundsAccessChainsConverted) { @@ -173,38 +142,18 @@ OpName %gl_FragColor "gl_FragColor" %_ptr_Function_v4float = OpTypePointer Function %v4float %_ptr_Output_v4float = OpTypePointer Output %v4float %gl_FragColor = OpVariable %_ptr_Output_v4float Output -)"; - - const std::string predefs_after = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor -OpExecutionMode %main OriginUpperLeft -OpSource GLSL 140 -OpName %main "main" -OpName %S_t "S_t" -OpMemberName %S_t 0 "v0" -OpMemberName %S_t 1 "v1" -OpName %s0 "s0" -OpName %BaseColor "BaseColor" -OpName %gl_FragColor "gl_FragColor" -%void = OpTypeVoid -%8 = OpTypeFunction %void -%float = OpTypeFloat 32 -%v4float = OpTypeVector %float 4 -%S_t = OpTypeStruct %v4float %v4float -%_ptr_Function_S_t = OpTypePointer Function %S_t -%int = OpTypeInt 32 1 -%_ptr_Input_v4float = OpTypePointer Input %v4float -%BaseColor = OpVariable %_ptr_Input_v4float Input -%_ptr_Function_v4float = OpTypePointer Function %v4float -%_ptr_Output_v4float = OpTypePointer Output %v4float -%gl_FragColor = OpVariable %_ptr_Output_v4float Output )"; const std::string before = - R"(%main = OpFunction %void None %8 + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex1:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 %17 = OpLabel %s0 = OpVariable %_ptr_Function_S_t Function %18 = OpLoad %v4float %BaseColor @@ -217,23 +166,8 @@ OpReturn OpFunctionEnd )"; - const std::string after = - R"(%main = OpFunction %void None %8 -%17 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function -%18 = OpLoad %v4float %BaseColor -%22 = OpLoad %S_t %s0 -%23 = OpCompositeInsert %S_t %18 %22 1 -OpStore %s0 %23 -%24 = OpLoad %S_t %s0 -%25 = OpCompositeExtract %v4float %24 1 -OpStore %gl_FragColor %25 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck( - predefs_before + before, predefs_after + after, true, true); + SinglePassRunAndMatch(predefs_before + before, + true); } TEST_F(LocalAccessChainConvertTest, TwoUsesofSingleChainConverted) { @@ -280,38 +214,18 @@ OpName %gl_FragColor "gl_FragColor" %_ptr_Function_v4float = OpTypePointer Function %v4float %_ptr_Output_v4float = OpTypePointer Output %v4float %gl_FragColor = OpVariable %_ptr_Output_v4float Output -)"; - - const std::string predefs_after = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor -OpExecutionMode %main OriginUpperLeft -OpSource GLSL 140 -OpName %main "main" -OpName %S_t "S_t" -OpMemberName %S_t 0 "v0" -OpMemberName %S_t 1 "v1" -OpName %s0 "s0" -OpName %BaseColor "BaseColor" -OpName %gl_FragColor "gl_FragColor" -%void = OpTypeVoid -%8 = OpTypeFunction %void -%float = OpTypeFloat 32 -%v4float = OpTypeVector %float 4 -%S_t = OpTypeStruct %v4float %v4float -%_ptr_Function_S_t = OpTypePointer Function %S_t -%int = OpTypeInt 32 1 -%_ptr_Input_v4float = OpTypePointer Input %v4float -%BaseColor = OpVariable %_ptr_Input_v4float Input -%_ptr_Function_v4float = OpTypePointer Function %v4float -%_ptr_Output_v4float = OpTypePointer Output %v4float -%gl_FragColor = OpVariable %_ptr_Output_v4float Output )"; const std::string before = - R"(%main = OpFunction %void None %8 + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex1:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 %17 = OpLabel %s0 = OpVariable %_ptr_Function_S_t Function %18 = OpLoad %v4float %BaseColor @@ -323,23 +237,8 @@ OpReturn OpFunctionEnd )"; - const std::string after = - R"(%main = OpFunction %void None %8 -%17 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function -%18 = OpLoad %v4float %BaseColor -%21 = OpLoad %S_t %s0 -%22 = OpCompositeInsert %S_t %18 %21 1 -OpStore %s0 %22 -%23 = OpLoad %S_t %s0 -%24 = OpCompositeExtract %v4float %23 1 -OpStore %gl_FragColor %24 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck( - predefs_before + before, predefs_after + after, true, true); + SinglePassRunAndMatch(predefs_before + before, + true); } TEST_F(LocalAccessChainConvertTest, OpaqueConverted) { @@ -347,7 +246,8 @@ TEST_F(LocalAccessChainConvertTest, OpaqueConverted) { // at the moment const std::string predefs = - R"(OpCapability Shader + R"( +OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" %outColor %texCoords @@ -390,9 +290,15 @@ OpDecorate %sampler15 DescriptorSet 0 )"; const std::string before = - R"(%main = OpFunction %void None %12 + R"( +; CHECK: [[l1:%\w+]] = OpLoad %S_t %param +; CHECK: [[e1:%\w+]] = OpCompositeExtract {{%\w+}} [[l1]] 2 +; CHECK: [[l2:%\w+]] = OpLoad %S_t %param +; CHECK: [[e2:%\w+]] = OpCompositeExtract {{%\w+}} [[l2]] 0 +; CHECK: OpImageSampleImplicitLod {{%\w+}} [[e1]] [[e2]] +%main = OpFunction %void None %12 %28 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function +%s0 = OpVariable %_ptr_Function_S_t Function %param = OpVariable %_ptr_Function_S_t Function %29 = OpLoad %v2float %texCoords %30 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 @@ -400,7 +306,7 @@ OpStore %30 %29 %31 = OpLoad %18 %sampler15 %32 = OpAccessChain %_ptr_Function_18 %s0 %int_2 OpStore %32 %31 -%33 = OpLoad %S_t %s0 +%33 = OpLoad %S_t %s0 OpStore %param %33 %34 = OpAccessChain %_ptr_Function_18 %param %int_2 %35 = OpLoad %18 %34 @@ -410,31 +316,6 @@ OpStore %param %33 OpStore %outColor %38 OpReturn OpFunctionEnd -)"; - - const std::string after = - R"(%main = OpFunction %void None %12 -%28 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function -%param = OpVariable %_ptr_Function_S_t Function -%29 = OpLoad %v2float %texCoords -%45 = OpLoad %S_t %s0 -%46 = OpCompositeInsert %S_t %29 %45 0 -OpStore %s0 %46 -%31 = OpLoad %18 %sampler15 -%47 = OpLoad %S_t %s0 -%48 = OpCompositeInsert %S_t %31 %47 2 -OpStore %s0 %48 -%33 = OpLoad %S_t %s0 -OpStore %param %33 -%49 = OpLoad %S_t %param -%50 = OpCompositeExtract %18 %49 2 -%51 = OpLoad %S_t %param -%52 = OpCompositeExtract %v2float %51 0 -%38 = OpImageSampleImplicitLod %v4float %50 %52 -OpStore %outColor %38 -OpReturn -OpFunctionEnd )"; const std::string remain = @@ -451,8 +332,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before + remain, predefs + after + remain, true, true); + SinglePassRunAndMatch(predefs + before + remain, + true); } TEST_F(LocalAccessChainConvertTest, NestedStructsConverted) { @@ -507,41 +388,18 @@ OpName %gl_FragColor "gl_FragColor" %_ptr_Function_v4float = OpTypePointer Function %v4float %_ptr_Output_v4float = OpTypePointer Output %v4float %gl_FragColor = OpVariable %_ptr_Output_v4float Output -)"; - - const std::string predefs_after = - R"(OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor -OpExecutionMode %main OriginUpperLeft -OpSource GLSL 140 -OpName %main "main" -OpName %S1_t "S1_t" -OpMemberName %S1_t 0 "v1" -OpName %S2_t "S2_t" -OpMemberName %S2_t 0 "v2" -OpMemberName %S2_t 1 "s1" -OpName %s2 "s2" -OpName %BaseColor "BaseColor" -OpName %gl_FragColor "gl_FragColor" -%void = OpTypeVoid -%9 = OpTypeFunction %void -%float = OpTypeFloat 32 -%v4float = OpTypeVector %float 4 -%S1_t = OpTypeStruct %v4float -%S2_t = OpTypeStruct %v4float %S1_t -%_ptr_Function_S2_t = OpTypePointer Function %S2_t -%int = OpTypeInt 32 1 -%_ptr_Input_v4float = OpTypePointer Input %v4float -%BaseColor = OpVariable %_ptr_Input_v4float Input -%_ptr_Function_v4float = OpTypePointer Function %v4float -%_ptr_Output_v4float = OpTypePointer Output %v4float -%gl_FragColor = OpVariable %_ptr_Output_v4float Output )"; const std::string before = - R"(%main = OpFunction %void None %9 + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S2_t %s2 +; CHECK: [[ex1:%\w+]] = OpCompositeInsert %S2_t [[st_id]] [[ld1]] 1 0 +; CHECK: OpStore %s2 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S2_t %s2 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 0 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %9 %19 = OpLabel %s2 = OpVariable %_ptr_Function_S2_t Function %20 = OpLoad %v4float %BaseColor @@ -554,25 +412,221 @@ OpReturn OpFunctionEnd )"; - const std::string after = - R"(%main = OpFunction %void None %9 -%19 = OpLabel -%s2 = OpVariable %_ptr_Function_S2_t Function -%20 = OpLoad %v4float %BaseColor -%24 = OpLoad %S2_t %s2 -%25 = OpCompositeInsert %S2_t %20 %24 1 0 -OpStore %s2 %25 -%26 = OpLoad %S2_t %s2 -%27 = OpCompositeExtract %v4float %26 1 0 -OpStore %gl_FragColor %27 + SinglePassRunAndMatch(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, SomeAccessChainsHaveNoUse) { + // Based on HLSL source code: + // struct S { + // float f; + // }; + + // float main(float input : A) : B { + // S local = { input }; + // return local.f; + // } + + const std::string predefs = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" %in_var_A %out_var_B +OpName %main "main" +OpName %in_var_A "in.var.A" +OpName %out_var_B "out.var.B" +OpName %S "S" +OpName %local "local" +%int = OpTypeInt 32 1 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_float = OpTypePointer Input %float +%_ptr_Output_float = OpTypePointer Output %float +%S = OpTypeStruct %float +%_ptr_Function_S = OpTypePointer Function %S +%int_0 = OpConstant %int 0 +%in_var_A = OpVariable %_ptr_Input_float Input +%out_var_B = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %8 +%15 = OpLabel +%local = OpVariable %_ptr_Function_S Function +%16 = OpLoad %float %in_var_A +%17 = OpCompositeConstruct %S %16 +OpStore %local %17 +)"; + + const std::string before = + R"( +; CHECK: [[ld:%\w+]] = OpLoad %S %local +; CHECK: [[ex:%\w+]] = OpCompositeExtract %float [[ld]] 0 +; CHECK: OpStore %out_var_B [[ex]] +%18 = OpAccessChain %_ptr_Function_float %local %int_0 +%19 = OpAccessChain %_ptr_Function_float %local %int_0 +%20 = OpLoad %float %18 +OpStore %out_var_B %20 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs_before + before, predefs_after + after, true, true); + SinglePassRunAndMatch(predefs + before, true); } +TEST_F(LocalAccessChainConvertTest, + StructOfVecsOfFloatConvertedWithDecorationOnLoad) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %21 RelaxedPrecision +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: OpDecorate +; CHECK: OpDecorate [[ld2:%\w+]] RelaxedPrecision +; CHECK-NOT: OpDecorate +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ins:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ins]] +; CHECK: [[ld2]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, + StructOfVecsOfFloatConvertedWithDecorationOnStore) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %s0 RelaxedPrecision +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: OpDecorate +; CHECK: OpDecorate [[ld1:%\w+]] RelaxedPrecision +; CHECK: OpDecorate [[ins:%\w+]] RelaxedPrecision +; CHECK-NOT: OpDecorate +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1]] = OpLoad %S_t %s0 +; CHECK: [[ins]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ins]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs_before + before, + true); +} +#endif // SPIRV_EFFCEE + TEST_F(LocalAccessChainConvertTest, DynamicallyIndexedVarNotConverted) { // #version 140 // @@ -645,69 +699,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, - false, true); -} - -TEST_F(LocalAccessChainConvertTest, SomeAccessChainsHaveNoUse) { - // Based on HLSL source code: - // struct S { - // float f; - // }; - - // float main(float input : A) : B { - // S local = { input }; - // return local.f; - // } - - const std::string predefs = R"(OpCapability Shader -OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %main "main" %in_var_A %out_var_B -OpName %main "main" -OpName %in_var_A "in.var.A" -OpName %out_var_B "out.var.B" -OpName %S "S" -OpName %local "local" -%int = OpTypeInt 32 1 -%void = OpTypeVoid -%8 = OpTypeFunction %void -%float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Input_float = OpTypePointer Input %float -%_ptr_Output_float = OpTypePointer Output %float -%S = OpTypeStruct %float -%_ptr_Function_S = OpTypePointer Function %S -%int_0 = OpConstant %int 0 -%in_var_A = OpVariable %_ptr_Input_float Input -%out_var_B = OpVariable %_ptr_Output_float Output -%main = OpFunction %void None %8 -%15 = OpLabel -%local = OpVariable %_ptr_Function_S Function -%16 = OpLoad %float %in_var_A -%17 = OpCompositeConstruct %S %16 -OpStore %local %17 -)"; - - const std::string before = - R"(%18 = OpAccessChain %_ptr_Function_float %local %int_0 -%19 = OpAccessChain %_ptr_Function_float %local %int_0 -%20 = OpLoad %float %18 -OpStore %out_var_B %20 -OpReturn -OpFunctionEnd -)"; - - const std::string after = - R"(%19 = OpAccessChain %_ptr_Function_float %local %int_0 -%21 = OpLoad %S %local -%22 = OpCompositeExtract %float %21 0 -OpStore %out_var_B %22 -OpReturn -OpFunctionEnd -)"; - - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(assembly, assembly, false, + true); } // TODO(greg-lunarg): Add tests to verify handling of these cases: @@ -719,4 +712,6 @@ OpFunctionEnd // OpInBoundsAccessChain // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/local_redundancy_elimination_test.cpp b/3rdparty/spirv-tools/test/opt/local_redundancy_elimination_test.cpp index 70ccf7bde..bdaafb85f 100644 --- a/3rdparty/spirv-tools/test/opt/local_redundancy_elimination_test.cpp +++ b/3rdparty/spirv-tools/test/opt/local_redundancy_elimination_test.cpp @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/value_number_table.h" +#include -#include "assembly_builder.h" #include "gmock/gmock.h" -#include "opt/build_module.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ::testing::HasSubstr; using ::testing::MatchesRegex; - using LocalRedundancyEliminationTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -54,7 +54,7 @@ TEST_F(LocalRedundancyEliminationTest, RemoveRedundantAdd) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Make sure we keep instruction that are different, but look similar. @@ -85,7 +85,7 @@ TEST_F(LocalRedundancyEliminationTest, KeepDifferentAdd) { OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // This test is check that the values are being propagated properly, and that @@ -123,7 +123,7 @@ TEST_F(LocalRedundancyEliminationTest, RemoveMultipleInstructions) { OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Redundant instructions in different blocks should be kept. @@ -152,7 +152,10 @@ TEST_F(LocalRedundancyEliminationTest, KeepInstructionsInDifferentBlocks) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/local_single_block_elim.cpp b/3rdparty/spirv-tools/test/opt/local_single_block_elim.cpp index 870103509..da7540e6b 100644 --- a/3rdparty/spirv-tools/test/opt/local_single_block_elim.cpp +++ b/3rdparty/spirv-tools/test/opt/local_single_block_elim.cpp @@ -13,13 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using LocalSingleBlockLoadStoreElimTest = PassTest<::testing::Test>; TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleStoreLoadElim) { @@ -73,13 +75,12 @@ OpFunctionEnd %v = OpVariable %_ptr_Function_v4float Function %14 = OpLoad %v4float %BaseColor OpStore %v %14 -%15 = OpLoad %v4float %v OpStore %gl_FragColor %14 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs_before + before, predefs_before + after, true, true); } @@ -174,105 +175,191 @@ OpBranch %29 %31 = OpLoad %v4float %v %32 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0 OpStore %32 %31 -%33 = OpLoad %v4float %v %34 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1 OpStore %34 %31 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before, predefs + after, true, true); } +TEST_F(LocalSingleBlockLoadStoreElimTest, StoreStoreElim) { + // + // Note first store to v is eliminated + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // vec4 v = BaseColor; + // v = v * 0.5; + // OutColor = v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%float_0_5 = OpConstant %float 0.5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%14 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%15 = OpLoad %v4float %BaseColor +OpStore %v %15 +%16 = OpLoad %v4float %v +%17 = OpVectorTimesScalar %v4float %16 %float_0_5 +OpStore %v %17 +%18 = OpLoad %v4float %v +OpStore %OutColor %18 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%14 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%15 = OpLoad %v4float %BaseColor +%17 = OpVectorTimesScalar %v4float %15 %float_0_5 +OpStore %v %17 +OpStore %OutColor %17 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + TEST_F(LocalSingleBlockLoadStoreElimTest, NoStoreElimIfInterveningAccessChainLoad) { // - // Note that even though the Load to %v is eliminated, the Store to %v - // is not eliminated due to the following access chain reference. + // Note the first Store to %v is not eliminated due to the following access + // chain reference. // - // #version 140 + // #version 450 // - // in vec4 BaseColor; - // flat in int Idx; + // layout(location = 0) in vec4 BaseColor0; + // layout(location = 1) in vec4 BaseColor1; + // layout(location = 2) flat in int Idx; + // layout(location = 0) out vec4 OutColor; // // void main() // { - // vec4 v = BaseColor; + // vec4 v = BaseColor0; // float f = v[Idx]; - // gl_FragColor = v/f; + // v = BaseColor1 + vec4(0.1); + // OutColor = v/f; // } const std::string predefs = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" %BaseColor %Idx %gl_FragColor +OpEntryPoint Fragment %main "main" %BaseColor0 %Idx %BaseColor1 %OutColor OpExecutionMode %main OriginUpperLeft -OpSource GLSL 140 +OpSource GLSL 450 OpName %main "main" OpName %v "v" -OpName %BaseColor "BaseColor" +OpName %BaseColor0 "BaseColor0" OpName %f "f" OpName %Idx "Idx" -OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor1 "BaseColor1" +OpName %OutColor "OutColor" +OpDecorate %BaseColor0 Location 0 OpDecorate %Idx Flat +OpDecorate %Idx Location 2 +OpDecorate %BaseColor1 Location 1 +OpDecorate %OutColor Location 0 %void = OpTypeVoid -%9 = OpTypeFunction %void +%10 = OpTypeFunction %void %float = OpTypeFloat 32 %v4float = OpTypeVector %float 4 %_ptr_Function_v4float = OpTypePointer Function %v4float %_ptr_Input_v4float = OpTypePointer Input %v4float -%BaseColor = OpVariable %_ptr_Input_v4float Input +%BaseColor0 = OpVariable %_ptr_Input_v4float Input %_ptr_Function_float = OpTypePointer Function %float %int = OpTypeInt 32 1 %_ptr_Input_int = OpTypePointer Input %int %Idx = OpVariable %_ptr_Input_int Input +%BaseColor1 = OpVariable %_ptr_Input_v4float Input +%float_0_100000001 = OpConstant %float 0.100000001 +%19 = OpConstantComposite %v4float %float_0_100000001 %float_0_100000001 %float_0_100000001 %float_0_100000001 %_ptr_Output_v4float = OpTypePointer Output %v4float -%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%OutColor = OpVariable %_ptr_Output_v4float Output )"; const std::string before = - R"(%main = OpFunction %void None %9 -%18 = OpLabel + R"(%main = OpFunction %void None %10 +%21 = OpLabel %v = OpVariable %_ptr_Function_v4float Function %f = OpVariable %_ptr_Function_float Function -%19 = OpLoad %v4float %BaseColor -OpStore %v %19 -%20 = OpLoad %int %Idx -%21 = OpAccessChain %_ptr_Function_float %v %20 -%22 = OpLoad %float %21 -OpStore %f %22 -%23 = OpLoad %v4float %v -%24 = OpLoad %float %f -%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 -%26 = OpFDiv %v4float %23 %25 -OpStore %gl_FragColor %26 +%22 = OpLoad %v4float %BaseColor0 +OpStore %v %22 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%25 = OpLoad %float %24 +OpStore %f %25 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +%28 = OpLoad %v4float %v +%29 = OpLoad %float %f +%30 = OpCompositeConstruct %v4float %29 %29 %29 %29 +%31 = OpFDiv %v4float %28 %30 +OpStore %OutColor %31 OpReturn OpFunctionEnd )"; const std::string after = - R"(%main = OpFunction %void None %9 -%18 = OpLabel + R"(%main = OpFunction %void None %10 +%21 = OpLabel %v = OpVariable %_ptr_Function_v4float Function %f = OpVariable %_ptr_Function_float Function -%19 = OpLoad %v4float %BaseColor -OpStore %v %19 -%20 = OpLoad %int %Idx -%21 = OpAccessChain %_ptr_Function_float %v %20 -%22 = OpLoad %float %21 -OpStore %f %22 -%23 = OpLoad %v4float %v -%24 = OpLoad %float %f -%25 = OpCompositeConstruct %v4float %22 %22 %22 %22 -%26 = OpFDiv %v4float %19 %25 -OpStore %gl_FragColor %26 +%22 = OpLoad %v4float %BaseColor0 +OpStore %v %22 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%25 = OpLoad %float %24 +OpStore %f %25 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +%30 = OpCompositeConstruct %v4float %25 %25 %25 %25 +%31 = OpFDiv %v4float %27 %30 +OpStore %OutColor %31 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before, predefs + after, true, true); } @@ -330,8 +417,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - assembly, assembly, false, true); + SinglePassRunAndCheck(assembly, assembly, + false, true); } TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningFunctionCall) { @@ -386,8 +473,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - assembly, assembly, false, true); + SinglePassRunAndCheck(assembly, assembly, + false, true); } TEST_F(LocalSingleBlockLoadStoreElimTest, ElimIfCopyObjectInFunction) { @@ -464,21 +551,19 @@ OpFunctionEnd %v2 = OpVariable %_ptr_Function_v4float Function %23 = OpLoad %v4float %BaseColor OpStore %v1 %23 -%24 = OpLoad %v4float %v1 %25 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0 OpStore %25 %23 %26 = OpLoad %v4float %BaseColor %27 = OpVectorTimesScalar %v4float %26 %float_0_5 %28 = OpCopyObject %_ptr_Function_v4float %v2 OpStore %28 %27 -%29 = OpLoad %v4float %28 %30 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1 OpStore %30 %27 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before, predefs + after, true, true); } @@ -530,17 +615,17 @@ OpDecorate %sampler15 DescriptorSet 0 const std::string before = R"(%main = OpFunction %void None %12 %28 = OpLabel -%s0 = OpVariable %_ptr_Function_S_t Function +%s0 = OpVariable %_ptr_Function_S_t Function %param = OpVariable %_ptr_Function_S_t Function %29 = OpLoad %v2float %texCoords -%30 = OpLoad %S_t %s0 +%30 = OpLoad %S_t %s0 %31 = OpCompositeInsert %S_t %29 %30 0 OpStore %s0 %31 %32 = OpLoad %18 %sampler15 -%33 = OpLoad %S_t %s0 +%33 = OpLoad %S_t %s0 %34 = OpCompositeInsert %S_t %32 %33 2 OpStore %s0 %34 -%35 = OpLoad %S_t %s0 +%35 = OpLoad %S_t %s0 OpStore %param %35 %36 = OpLoad %S_t %param %37 = OpCompositeExtract %18 %36 2 @@ -560,16 +645,11 @@ OpFunctionEnd %29 = OpLoad %v2float %texCoords %30 = OpLoad %S_t %s0 %31 = OpCompositeInsert %S_t %29 %30 0 -OpStore %s0 %31 %32 = OpLoad %18 %sampler15 -%33 = OpLoad %S_t %s0 %34 = OpCompositeInsert %S_t %32 %31 2 OpStore %s0 %34 -%35 = OpLoad %S_t %s0 OpStore %param %34 -%36 = OpLoad %S_t %param %37 = OpCompositeExtract %18 %34 2 -%38 = OpLoad %S_t %param %39 = OpCompositeExtract %v2float %34 0 %40 = OpImageSampleImplicitLod %v4float %37 %39 OpStore %outColor %40 @@ -578,7 +658,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before, predefs + after, true, true); } @@ -681,12 +761,11 @@ OpFunctionEnd %t_0 = OpVariable %_ptr_Function_v4float Function %27 = OpLoad %v4float %v1_0 OpStore %t_0 %27 -%28 = OpLoad %v4float %t_0 OpReturnValue %27 OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + before, predefs + after, true, true); } @@ -778,7 +857,6 @@ OpDecorate %7 Binding 0 %23 = OpLabel %24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function OpStore %24 %7 -%26 = OpLoad %_ptr_Uniform__struct_5 %24 %27 = OpAccessChain %_ptr_Uniform_v4float %7 %int_0 %uint_0 %int_0 %28 = OpLoad %v4float %27 %29 = OpCopyObject %v4float %28 @@ -788,8 +866,199 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, - true, true); + SinglePassRunAndCheck(before, after, true, + true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, RedundantStore) { + // Test that checks if a pointer variable is removed. + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpStore %v %15 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, RedundantStore2) { + // Test that checks if a pointer variable is removed. + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpStore %v %15 +%17 = OpLoad %v4float %v +OpStore %v %17 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + +// Test that that an unused OpAccessChain between two store does does not +// hinders the removal of the first store. We need to check this because +// local-access-chain-convert does always remove the OpAccessChain instructions +// that become dead. + +TEST_F(LocalSingleBlockLoadStoreElimTest, + StoreElimIfInterveningUnusedAccessChain) { + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor0 %Idx %BaseColor1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor0 "BaseColor0" +OpName %Idx "Idx" +OpName %BaseColor1 "BaseColor1" +OpName %OutColor "OutColor" +OpDecorate %BaseColor0 Location 0 +OpDecorate %Idx Flat +OpDecorate %Idx Location 2 +OpDecorate %BaseColor1 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%Idx = OpVariable %_ptr_Input_int Input +%BaseColor1 = OpVariable %_ptr_Input_v4float Input +%float_0_100000001 = OpConstant %float 0.100000001 +%19 = OpConstantComposite %v4float %float_0_100000001 %float_0_100000001 %float_0_100000001 %float_0_100000001 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor0 +OpStore %v %22 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor0 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); } // TODO(greg-lunarg): Add tests to verify handling of these cases: // @@ -798,4 +1067,6 @@ OpFunctionEnd // Check for correctness in the presence of function calls // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/local_single_store_elim_test.cpp b/3rdparty/spirv-tools/test/opt/local_single_store_elim_test.cpp index a7e0e9093..23e82ba86 100644 --- a/3rdparty/spirv-tools/test/opt/local_single_store_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/local_single_store_elim_test.cpp @@ -13,13 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using LocalSingleStoreElimTest = PassTest<::testing::Test>; TEST_F(LocalSingleStoreElimTest, PositiveAndNegative) { @@ -112,7 +114,6 @@ OpBranchConditional %23 %25 %24 OpStore %f %float_0 OpBranch %24 %24 = OpLabel -%26 = OpLoad %v4float %v %27 = OpLoad %float %f %28 = OpCompositeConstruct %v4float %27 %27 %27 %27 %29 = OpFAdd %v4float %20 %28 @@ -121,8 +122,74 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSingleStoreElimTest, ThreeStores) { + // Three stores to multiple loads of v is not optimized. + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %r "r" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%r = OpVariable %_ptr_Function_v4float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +%22 = OpFOrdLessThan %bool %21 %float_0 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %25 +%24 = OpLabel +%26 = OpLoad %v4float %v +OpStore %v %26 +OpStore %r %26 +OpBranch %23 +%25 = OpLabel +%27 = OpLoad %v4float %v +%28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 +OpStore %v %28 +%29 = OpFSub %v4float %28 %27 +OpStore %r %29 +OpBranch %23 +%23 = OpLabel +%30 = OpLoad %v4float %r +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + before, true, true); } TEST_F(LocalSingleStoreElimTest, MultipleLoads) { @@ -211,11 +278,9 @@ OpStore %v %20 OpSelectionMerge %23 None OpBranchConditional %22 %24 %25 %24 = OpLabel -%26 = OpLoad %v4float %v OpStore %r %20 OpBranch %23 %25 = OpLabel -%27 = OpLoad %v4float %v %28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 %29 = OpFSub %v4float %28 %20 OpStore %r %29 @@ -227,8 +292,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSingleStoreElimTest, NoStoreElimWithInterveningAccessChainLoad) { @@ -299,16 +364,14 @@ OpStore %v %18 %19 = OpAccessChain %_ptr_Function_float %v %uint_3 %20 = OpLoad %float %19 OpStore %f %20 -%21 = OpLoad %v4float %v -%22 = OpLoad %float %f %23 = OpVectorTimesScalar %v4float %18 %20 OpStore %gl_FragColor %23 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSingleStoreElimTest, NoReplaceOfDominatingPartialStore) { @@ -362,8 +425,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, - true); + SinglePassRunAndCheck(assembly, assembly, true, + true); } TEST_F(LocalSingleStoreElimTest, ElimIfCopyObjectInFunction) { @@ -457,7 +520,6 @@ OpStore %f %float_0 OpBranch %24 %24 = OpLabel %26 = OpCopyObject %_ptr_Function_v4float %v -%27 = OpLoad %v4float %26 %28 = OpLoad %float %f %29 = OpCompositeConstruct %v4float %28 %28 %28 %28 %30 = OpFAdd %v4float %20 %29 @@ -466,8 +528,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSingleStoreElimTest, NoOptIfStoreNotDominating) { @@ -546,8 +608,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(assembly, assembly, true, - true); + SinglePassRunAndCheck(assembly, assembly, true, + true); } TEST_F(LocalSingleStoreElimTest, OptInitializedVariableLikeStore) { @@ -599,15 +661,14 @@ OpFunctionEnd R"(%main = OpFunction %void None %6 %12 = OpLabel %f = OpVariable %_ptr_Function_float Function %float_0 -%13 = OpLoad %float %f %14 = OpCompositeConstruct %v4float %float_0 %float_0 %float_0 %float_0 OpStore %gl_FragColor %14 OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSingleStoreElimTest, PointerVariable) { @@ -698,7 +759,6 @@ OpDecorate %7 Binding 0 %23 = OpLabel %24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function OpStore %24 %7 -%26 = OpLoad %_ptr_Uniform__struct_5 %24 %27 = OpAccessChain %_ptr_Uniform_v4float %7 %int_0 %uint_0 %int_0 %28 = OpLoad %v4float %27 %29 = OpCopyObject %v4float %28 @@ -708,13 +768,88 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, - true); + SinglePassRunAndCheck(before, after, true, true); } +// Test that that an unused OpAccessChain between a store and a use does does +// not hinders the replacement of the use. We need to check this because +// local-access-chain-convert does always remove the OpAccessChain instructions +// that become dead. + +TEST_F(LocalSingleStoreElimTest, + StoreElimWithUnusedInterveningAccessChainLoad) { + // Last load of v is eliminated, but access chain load and store of v isn't + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // float f = v[3]; + // gl_FragColor = v * f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%uint = OpTypeInt 32 0 +%uint_3 = OpConstant %uint 3 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpAccessChain %_ptr_Function_float %v %uint_3 +%21 = OpLoad %v4float %v +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpAccessChain %_ptr_Function_float %v %uint_3 +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Other types // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/local_ssa_elim_test.cpp b/3rdparty/spirv-tools/test/opt/local_ssa_elim_test.cpp index 06ecc254b..33419395b 100644 --- a/3rdparty/spirv-tools/test/opt/local_ssa_elim_test.cpp +++ b/3rdparty/spirv-tools/test/opt/local_ssa_elim_test.cpp @@ -13,13 +13,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using LocalSSAElimTest = PassTest<::testing::Test>; TEST_F(LocalSSAElimTest, ForLoop) { @@ -135,8 +138,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, NestedForLoop) { @@ -277,8 +280,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, ForLoopWithContinue) { @@ -423,7 +426,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + names + predefs2 + before, predefs + names + predefs2 + after, true, true); } @@ -564,8 +567,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, SwapProblem) { @@ -701,8 +704,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, LostCopyProblem) { @@ -845,8 +848,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, IfThenElse) { @@ -945,8 +948,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, IfThen) { @@ -1034,8 +1037,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, Switch) { @@ -1165,8 +1168,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, SwitchWithFallThrough) { @@ -1297,8 +1300,8 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( - predefs + before, predefs + after, true, true); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); } TEST_F(LocalSSAElimTest, DontPatchPhiInLoopHeaderThatIsNotAVar) { @@ -1328,8 +1331,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, before, true, - true); + SinglePassRunAndCheck(before, before, true, true); } TEST_F(LocalSSAElimTest, OptInitializedVariableLikeStore) { @@ -1426,7 +1428,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck( + SinglePassRunAndCheck( predefs + func_before, predefs + func_after, true, true); } @@ -1527,8 +1529,7 @@ OpFunctionEnd )"; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck(before, after, true, - true); + SinglePassRunAndCheck(before, after, true, true); } TEST_F(LocalSSAElimTest, VerifyInstToBlockMap) { @@ -1609,7 +1610,7 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, context); @@ -1617,10 +1618,10 @@ OpFunctionEnd // Force the instruction to block mapping to get built. context->get_instr_block(27u); - auto pass = MakeUnique(); + auto pass = MakeUnique(); pass->SetMessageConsumer(nullptr); const auto status = pass->Run(context.get()); - EXPECT_TRUE(status == opt::Pass::Status::SuccessWithChange); + EXPECT_TRUE(status == Pass::Status::SuccessWithChange); } // TODO(dneto): Add Effcee as required dependency, and make this unconditional. @@ -1630,7 +1631,7 @@ TEST_F(LocalSSAElimTest, CompositeExtractProblem) { OpCapability Tessellation %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 - OpEntryPoint TessellationControl %2 "main" + OpEntryPoint TessellationControl %2 "main" %16 %17 %18 %20 %22 %26 %27 %30 %31 %void = OpTypeVoid %4 = OpTypeFunction %void %float = OpTypeFloat 32 @@ -1716,7 +1717,46 @@ TEST_F(LocalSSAElimTest, CompositeExtractProblem) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(spv_asm, true); + SinglePassRunAndMatch(spv_asm, true); +} + +// Test that the RelaxedPrecision decoration on the variable to added to the +// result of the OpPhi instruction. +TEST_F(LocalSSAElimTest, DecoratedVariable) { + const std::string spv_asm = R"( +; CHECK: OpDecorate [[var:%\w+]] RelaxedPrecision +; CHECK: OpDecorate [[phi_id:%\w+]] RelaxedPrecision +; CHECK: [[phi_id]] = OpPhi + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "main" + OpDecorate %v RelaxedPrecision + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %int = OpTypeInt 32 0 + %int_p = OpTypePointer Function %int + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %2 = OpFunction %void None %func_t + %33 = OpLabel + %v = OpVariable %int_p Function + OpSelectionMerge %merge None + OpBranchConditional %true %l1 %l2 + %l1 = OpLabel + OpStore %v %int_1 + OpBranch %merge + %l2 = OpLabel + OpStore %v %int_0 + OpBranch %merge + %merge = OpLabel + %ld = OpLoad %int %v + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(spv_asm, true); } #endif @@ -1729,4 +1769,6 @@ TEST_F(LocalSSAElimTest, CompositeExtractProblem) { // unsupported extensions // Others? -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/CMakeLists.txt b/3rdparty/spirv-tools/test/opt/loop_optimizations/CMakeLists.txt index 53054814e..8c7971b7f 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/CMakeLists.txt @@ -13,80 +13,28 @@ # limitations under the License. -add_spvtools_unittest(TARGET loop_descriptor_simple - SRCS ../function_utils.h - loop_descriptions.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET loop_descriptor_nested - SRCS ../function_utils.h - nested_loops.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET lcssa_test - SRCS ../function_utils.h - lcssa.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET licm_all_loop_types - SRCS ../function_utils.h - hoist_all_loop_types.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET licm_hoist_independent_loops - SRCS ../function_utils.h - hoist_from_independent_loops.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET licm_hoist_double_nested_loops - SRCS ../function_utils.h - hoist_double_nested_loops.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET licm_hoist_single_nested_loops - SRCS ../function_utils.h - hoist_single_nested_loops.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET licm_hoist_simple_case - SRCS ../function_utils.h - hoist_simple_case.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET licm_hoist_no_preheader - SRCS ../function_utils.h - hoist_without_preheader.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET loop_unroll_simple - SRCS ../function_utils.h - unroll_simple.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET loop_unroll_assumtion_checks - SRCS ../function_utils.h - unroll_assumptions.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET unswitch_test - SRCS ../function_utils.h - unswitch.cpp - LIBS SPIRV-Tools-opt -) - -add_spvtools_unittest(TARGET peeling_test - SRCS ../function_utils.h - peeling.cpp - LIBS SPIRV-Tools-opt +add_spvtools_unittest(TARGET opt_loops + SRCS ../function_utils.h + dependence_analysis.cpp + dependence_analysis_helpers.cpp + fusion_compatibility.cpp + fusion_illegal.cpp + fusion_legal.cpp + fusion_pass.cpp + hoist_all_loop_types.cpp + hoist_double_nested_loops.cpp + hoist_from_independent_loops.cpp + hoist_simple_case.cpp + hoist_single_nested_loops.cpp + hoist_without_preheader.cpp + lcssa.cpp + loop_descriptions.cpp + loop_fission.cpp + nested_loops.cpp + peeling.cpp + peeling_pass.cpp + unroll_assumptions.cpp + unroll_simple.cpp + unswitch.cpp + LIBS SPIRV-Tools-opt ) diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/dependence_analysis.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/dependence_analysis.cpp new file mode 100644 index 000000000..8aeb20afc --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/dependence_analysis.cpp @@ -0,0 +1,4205 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/tree_iterator.h" +#include "test/opt//assembly_builder.h" +#include "test/opt//function_utils.h" +#include "test/opt//pass_fixture.h" +#include "test/opt//pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DependencyAnalysis = ::testing::Test; + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + int[10] arr; + int[10] arr2; + int a = 2; + for (int i = 0; i < 10; i++) { + arr[a] = arr[3]; + arr[a*2] = arr[a+3]; + arr[6] = arr2[6]; + arr[a+5] = arr2[7]; + } +} +*/ +TEST(DependencyAnalysis, ZIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %25 "arr" + OpName %39 "arr2" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %11 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %27 = OpConstant %6 3 + %38 = OpConstant %6 6 + %44 = OpConstant %6 5 + %46 = OpConstant %6 7 + %51 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %25 = OpVariable %24 Function + %39 = OpVariable %24 Function + OpBranch %12 + %12 = OpLabel + %53 = OpPhi %6 %11 %5 %52 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %53 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %28 = OpAccessChain %7 %25 %27 + %29 = OpLoad %6 %28 + %30 = OpAccessChain %7 %25 %9 + OpStore %30 %29 + %32 = OpIMul %6 %9 %9 + %34 = OpIAdd %6 %9 %27 + %35 = OpAccessChain %7 %25 %34 + %36 = OpLoad %6 %35 + %37 = OpAccessChain %7 %25 %32 + OpStore %37 %36 + %40 = OpAccessChain %7 %39 %38 + %41 = OpLoad %6 %40 + %42 = OpAccessChain %7 %25 %38 + OpStore %42 %41 + %45 = OpIAdd %6 %9 %44 + %47 = OpAccessChain %7 %39 %46 + %48 = OpLoad %6 %47 + %49 = OpAccessChain %7 %25 %45 + OpStore %49 %48 + OpBranch %15 + %15 = OpLabel + %52 = OpIAdd %6 %53 %51 + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 13)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // 29 -> 30 tests looking through constants. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(29), + store[0], &distance_vector)); + } + + // 36 -> 37 tests looking through additions. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(36), + store[1], &distance_vector)); + } + + // 41 -> 42 tests looking at same index across two different arrays. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(41), + store[2], &distance_vector)); + } + + // 48 -> 49 tests looking through additions for same index in two different + // arrays. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(48), + store[3], &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 c; +void main(){ + int[10] arr; + int[10] arr2; + int[10] arr3; + int[10] arr4; + int[10] arr5; + int N = int(c.x); + for (int i = 0; i < N; i++) { + arr[2*N] = arr[N]; + arr2[2*N+1] = arr2[N]; + arr3[2*N] = arr3[N-1]; + arr4[N] = arr5[N]; + } +} +*/ +TEST(DependencyAnalysis, SymbolicZIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %12 "c" + OpName %33 "arr" + OpName %41 "arr2" + OpName %50 "arr3" + OpName %58 "arr4" + OpName %60 "arr5" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %30 = OpConstant %13 10 + %31 = OpTypeArray %6 %30 + %32 = OpTypePointer Function %31 + %34 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %33 = OpVariable %32 Function + %41 = OpVariable %32 Function + %50 = OpVariable %32 Function + %58 = OpVariable %32 Function + %60 = OpVariable %32 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpBranch %21 + %21 = OpLabel + %67 = OpPhi %6 %20 %5 %66 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %67 %18 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + %36 = OpIMul %6 %34 %18 + %38 = OpAccessChain %7 %33 %18 + %39 = OpLoad %6 %38 + %40 = OpAccessChain %7 %33 %36 + OpStore %40 %39 + %43 = OpIMul %6 %34 %18 + %45 = OpIAdd %6 %43 %44 + %47 = OpAccessChain %7 %41 %18 + %48 = OpLoad %6 %47 + %49 = OpAccessChain %7 %41 %45 + OpStore %49 %48 + %52 = OpIMul %6 %34 %18 + %54 = OpISub %6 %18 %44 + %55 = OpAccessChain %7 %50 %54 + %56 = OpLoad %6 %55 + %57 = OpAccessChain %7 %50 %52 + OpStore %57 %56 + %62 = OpAccessChain %7 %60 %18 + %63 = OpLoad %6 %62 + %64 = OpAccessChain %7 %58 %18 + OpStore %64 %63 + OpBranch %24 + %24 = OpLabel + %66 = OpIAdd %6 %67 %44 + OpBranch %21 + %23 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // independent due to loop bounds (won't enter if N <= 0). + // 39 -> 40 tests looking through symbols and multiplicaiton. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(39), + store[0], &distance_vector)); + } + + // 48 -> 49 tests looking through symbols and multiplication + addition. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(48), + store[1], &distance_vector)); + } + + // 56 -> 57 tests looking through symbols and arithmetic on load and store. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(56), + store[2], &distance_vector)); + } + + // independent as different arrays + // 63 -> 64 tests looking through symbols and load/store from/to different + // arrays. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(63), + store[3], &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a(){ + int[10] arr; + int[11] arr2; + int[20] arr3; + int[20] arr4; + int a = 2; + for (int i = 0; i < 10; i++) { + arr[i] = arr[i]; + arr2[i] = arr2[i+1]; + arr3[i] = arr3[i-1]; + arr4[2*i] = arr4[i]; + } +} +void b(){ + int[10] arr; + int[11] arr2; + int[20] arr3; + int[20] arr4; + int a = 2; + for (int i = 10; i > 0; i--) { + arr[i] = arr[i]; + arr2[i] = arr2[i+1]; + arr3[i] = arr3[i-1]; + arr4[2*i] = arr4[i]; + } +} + +void main() { + a(); + b(); +} +*/ +TEST(DependencyAnalysis, SIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "a" + OpName %14 "i" + OpName %29 "arr" + OpName %38 "arr2" + OpName %49 "arr3" + OpName %56 "arr4" + OpName %65 "a" + OpName %66 "i" + OpName %74 "arr" + OpName %80 "arr2" + OpName %87 "arr3" + OpName %94 "arr4" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpConstant %10 2 + %15 = OpConstant %10 0 + %22 = OpConstant %10 10 + %23 = OpTypeBool + %25 = OpTypeInt 32 0 + %26 = OpConstant %25 10 + %27 = OpTypeArray %10 %26 + %28 = OpTypePointer Function %27 + %35 = OpConstant %25 11 + %36 = OpTypeArray %10 %35 + %37 = OpTypePointer Function %36 + %41 = OpConstant %10 1 + %46 = OpConstant %25 20 + %47 = OpTypeArray %10 %46 + %48 = OpTypePointer Function %47 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %103 = OpFunctionCall %2 %6 + %104 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %14 = OpVariable %11 Function + %29 = OpVariable %28 Function + %38 = OpVariable %37 Function + %49 = OpVariable %48 Function + %56 = OpVariable %48 Function + OpStore %12 %13 + OpStore %14 %15 + OpBranch %16 + %16 = OpLabel + %105 = OpPhi %10 %15 %7 %64 %19 + OpLoopMerge %18 %19 None + OpBranch %20 + %20 = OpLabel + %24 = OpSLessThan %23 %105 %22 + OpBranchConditional %24 %17 %18 + %17 = OpLabel + %32 = OpAccessChain %11 %29 %105 + %33 = OpLoad %10 %32 + %34 = OpAccessChain %11 %29 %105 + OpStore %34 %33 + %42 = OpIAdd %10 %105 %41 + %43 = OpAccessChain %11 %38 %42 + %44 = OpLoad %10 %43 + %45 = OpAccessChain %11 %38 %105 + OpStore %45 %44 + %52 = OpISub %10 %105 %41 + %53 = OpAccessChain %11 %49 %52 + %54 = OpLoad %10 %53 + %55 = OpAccessChain %11 %49 %105 + OpStore %55 %54 + %58 = OpIMul %10 %13 %105 + %60 = OpAccessChain %11 %56 %105 + %61 = OpLoad %10 %60 + %62 = OpAccessChain %11 %56 %58 + OpStore %62 %61 + OpBranch %19 + %19 = OpLabel + %64 = OpIAdd %10 %105 %41 + OpStore %14 %64 + OpBranch %16 + %18 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %65 = OpVariable %11 Function + %66 = OpVariable %11 Function + %74 = OpVariable %28 Function + %80 = OpVariable %37 Function + %87 = OpVariable %48 Function + %94 = OpVariable %48 Function + OpStore %65 %13 + OpStore %66 %22 + OpBranch %67 + %67 = OpLabel + %106 = OpPhi %10 %22 %9 %102 %70 + OpLoopMerge %69 %70 None + OpBranch %71 + %71 = OpLabel + %73 = OpSGreaterThan %23 %106 %15 + OpBranchConditional %73 %68 %69 + %68 = OpLabel + %77 = OpAccessChain %11 %74 %106 + %78 = OpLoad %10 %77 + %79 = OpAccessChain %11 %74 %106 + OpStore %79 %78 + %83 = OpIAdd %10 %106 %41 + %84 = OpAccessChain %11 %80 %83 + %85 = OpLoad %10 %84 + %86 = OpAccessChain %11 %80 %106 + OpStore %86 %85 + %90 = OpISub %10 %106 %41 + %91 = OpAccessChain %11 %87 %90 + %92 = OpLoad %10 %91 + %93 = OpAccessChain %11 %87 %106 + OpStore %93 %92 + %96 = OpIMul %10 %13 %106 + %98 = OpAccessChain %11 %94 %106 + %99 = OpLoad %10 %98 + %100 = OpAccessChain %11 %94 %96 + OpStore %100 %99 + OpBranch %70 + %70 = OpLabel + %102 = OpISub %10 %106 %41 + OpStore %66 %102 + OpBranch %67 + %69 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a. + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 17)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // = dependence + // 33 -> 34 tests looking at SIV in same array. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(33), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // > -1 dependence + // 44 -> 45 tests looking at SIV in same array with addition. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(44), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::GT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, -1); + } + + // < 1 dependence + // 54 -> 55 tests looking at SIV in same array with subtraction. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(54), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::LT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 1); + } + + // <=> dependence + // 61 -> 62 tests looking at SIV in same array with multiplication. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(61), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } + } + // For the loop in function b. + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 68)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // = dependence + // 78 -> 79 tests looking at SIV in same array. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(78), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // < 1 dependence + // 85 -> 86 tests looking at SIV in same array with addition. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(85), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::LT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 1); + } + + // > -1 dependence + // 92 -> 93 tests looking at SIV in same array with subtraction. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(92), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::GT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, -1); + } + + // <=> dependence + // 99 -> 100 tests looking at SIV in same array with multiplication. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(99), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 c; +void a() { + int[13] arr; + int[15] arr2; + int[18] arr3; + int[18] arr4; + int N = int(c.x); + int C = 2; + int a = 2; + for (int i = 0; i < N; i++) { // Bounds are N - 1 + arr[i+2*N] = arr[i+N]; // |distance| = N + arr2[i+N] = arr2[i+2*N] + C; // |distance| = N + arr3[2*i+2*N+1] = arr3[2*i+N+1]; // |distance| = N + arr4[a*i+N+1] = arr4[a*i+2*N+1]; // |distance| = N + } +} +void b() { + int[13] arr; + int[15] arr2; + int[18] arr3; + int[18] arr4; + int N = int(c.x); + int C = 2; + int a = 2; + for (int i = N; i > 0; i--) { // Bounds are N - 1 + arr[i+2*N] = arr[i+N]; // |distance| = N + arr2[i+N] = arr2[i+2*N] + C; // |distance| = N + arr3[2*i+2*N+1] = arr3[2*i+N+1]; // |distance| = N + arr4[a*i+N+1] = arr4[a*i+2*N+1]; // |distance| = N + } +} +void main(){ + a(); + b(); +}*/ +TEST(DependencyAnalysis, SymbolicSIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %16 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "N" + OpName %16 "c" + OpName %23 "C" + OpName %25 "a" + OpName %26 "i" + OpName %40 "arr" + OpName %54 "arr2" + OpName %70 "arr3" + OpName %86 "arr4" + OpName %105 "N" + OpName %109 "C" + OpName %110 "a" + OpName %111 "i" + OpName %120 "arr" + OpName %131 "arr2" + OpName %144 "arr3" + OpName %159 "arr4" + OpDecorate %16 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Input %14 + %16 = OpVariable %15 Input + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %19 = OpTypePointer Input %13 + %24 = OpConstant %10 2 + %27 = OpConstant %10 0 + %35 = OpTypeBool + %37 = OpConstant %17 13 + %38 = OpTypeArray %10 %37 + %39 = OpTypePointer Function %38 + %51 = OpConstant %17 15 + %52 = OpTypeArray %10 %51 + %53 = OpTypePointer Function %52 + %67 = OpConstant %17 18 + %68 = OpTypeArray %10 %67 + %69 = OpTypePointer Function %68 + %76 = OpConstant %10 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %178 = OpFunctionCall %2 %6 + %179 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %23 = OpVariable %11 Function + %25 = OpVariable %11 Function + %26 = OpVariable %11 Function + %40 = OpVariable %39 Function + %54 = OpVariable %53 Function + %70 = OpVariable %69 Function + %86 = OpVariable %69 Function + %20 = OpAccessChain %19 %16 %18 + %21 = OpLoad %13 %20 + %22 = OpConvertFToS %10 %21 + OpStore %12 %22 + OpStore %23 %24 + OpStore %25 %24 + OpStore %26 %27 + OpBranch %28 + %28 = OpLabel + %180 = OpPhi %10 %27 %7 %104 %31 + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %36 = OpSLessThan %35 %180 %22 + OpBranchConditional %36 %29 %30 + %29 = OpLabel + %43 = OpIMul %10 %24 %22 + %44 = OpIAdd %10 %180 %43 + %47 = OpIAdd %10 %180 %22 + %48 = OpAccessChain %11 %40 %47 + %49 = OpLoad %10 %48 + %50 = OpAccessChain %11 %40 %44 + OpStore %50 %49 + %57 = OpIAdd %10 %180 %22 + %60 = OpIMul %10 %24 %22 + %61 = OpIAdd %10 %180 %60 + %62 = OpAccessChain %11 %54 %61 + %63 = OpLoad %10 %62 + %65 = OpIAdd %10 %63 %24 + %66 = OpAccessChain %11 %54 %57 + OpStore %66 %65 + %72 = OpIMul %10 %24 %180 + %74 = OpIMul %10 %24 %22 + %75 = OpIAdd %10 %72 %74 + %77 = OpIAdd %10 %75 %76 + %79 = OpIMul %10 %24 %180 + %81 = OpIAdd %10 %79 %22 + %82 = OpIAdd %10 %81 %76 + %83 = OpAccessChain %11 %70 %82 + %84 = OpLoad %10 %83 + %85 = OpAccessChain %11 %70 %77 + OpStore %85 %84 + %89 = OpIMul %10 %24 %180 + %91 = OpIAdd %10 %89 %22 + %92 = OpIAdd %10 %91 %76 + %95 = OpIMul %10 %24 %180 + %97 = OpIMul %10 %24 %22 + %98 = OpIAdd %10 %95 %97 + %99 = OpIAdd %10 %98 %76 + %100 = OpAccessChain %11 %86 %99 + %101 = OpLoad %10 %100 + %102 = OpAccessChain %11 %86 %92 + OpStore %102 %101 + OpBranch %31 + %31 = OpLabel + %104 = OpIAdd %10 %180 %76 + OpStore %26 %104 + OpBranch %28 + %30 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %105 = OpVariable %11 Function + %109 = OpVariable %11 Function + %110 = OpVariable %11 Function + %111 = OpVariable %11 Function + %120 = OpVariable %39 Function + %131 = OpVariable %53 Function + %144 = OpVariable %69 Function + %159 = OpVariable %69 Function + %106 = OpAccessChain %19 %16 %18 + %107 = OpLoad %13 %106 + %108 = OpConvertFToS %10 %107 + OpStore %105 %108 + OpStore %109 %24 + OpStore %110 %24 + OpStore %111 %108 + OpBranch %113 + %113 = OpLabel + %181 = OpPhi %10 %108 %9 %177 %116 + OpLoopMerge %115 %116 None + OpBranch %117 + %117 = OpLabel + %119 = OpSGreaterThan %35 %181 %27 + OpBranchConditional %119 %114 %115 + %114 = OpLabel + %123 = OpIMul %10 %24 %108 + %124 = OpIAdd %10 %181 %123 + %127 = OpIAdd %10 %181 %108 + %128 = OpAccessChain %11 %120 %127 + %129 = OpLoad %10 %128 + %130 = OpAccessChain %11 %120 %124 + OpStore %130 %129 + %134 = OpIAdd %10 %181 %108 + %137 = OpIMul %10 %24 %108 + %138 = OpIAdd %10 %181 %137 + %139 = OpAccessChain %11 %131 %138 + %140 = OpLoad %10 %139 + %142 = OpIAdd %10 %140 %24 + %143 = OpAccessChain %11 %131 %134 + OpStore %143 %142 + %146 = OpIMul %10 %24 %181 + %148 = OpIMul %10 %24 %108 + %149 = OpIAdd %10 %146 %148 + %150 = OpIAdd %10 %149 %76 + %152 = OpIMul %10 %24 %181 + %154 = OpIAdd %10 %152 %108 + %155 = OpIAdd %10 %154 %76 + %156 = OpAccessChain %11 %144 %155 + %157 = OpLoad %10 %156 + %158 = OpAccessChain %11 %144 %150 + OpStore %158 %157 + %162 = OpIMul %10 %24 %181 + %164 = OpIAdd %10 %162 %108 + %165 = OpIAdd %10 %164 %76 + %168 = OpIMul %10 %24 %181 + %170 = OpIMul %10 %24 %108 + %171 = OpIAdd %10 %168 %170 + %172 = OpIAdd %10 %171 %76 + %173 = OpAccessChain %11 %159 %172 + %174 = OpLoad %10 %173 + %175 = OpAccessChain %11 %159 %165 + OpStore %175 %174 + OpBranch %116 + %116 = OpLabel + %177 = OpISub %10 %181 %76 + OpStore %111 %177 + OpBranch %113 + %115 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a. + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // independent due to loop bounds (won't enter when N <= 0) + // 49 -> 50 tests looking through SIV and symbols with multiplication + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(49), store[0], &distance_vector)); + } + + // 63 -> 66 tests looking through SIV and symbols with multiplication and + + // C + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(63), + store[1], &distance_vector)); + } + + // 84 -> 85 tests looking through arithmetic on SIV and symbols + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(84), store[2], &distance_vector)); + } + + // 101 -> 102 tests looking through symbol arithmetic on SIV and symbols + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(101), store[3], &distance_vector)); + } + } + // For the loop in function b. + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 114)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // independent due to loop bounds (won't enter when N <= 0). + // 129 -> 130 tests looking through SIV and symbols with multiplication. + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(129), store[0], &distance_vector)); + } + + // 140 -> 143 tests looking through SIV and symbols with multiplication and + // + C. + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(140), store[1], &distance_vector)); + } + + // 157 -> 158 tests looking through arithmetic on SIV and symbols. + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(157), store[2], &distance_vector)); + } + + // 174 -> 175 tests looking through symbol arithmetic on SIV and symbols. + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(174), store[3], &distance_vector)); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + int[6] arr; + int N = 5; + for (int i = 1; i < N; i++) { + arr[i] = arr[N-i]; + } +} +void b() { + int[6] arr; + int N = 5; + for (int i = 1; i < N; i++) { + arr[N-i] = arr[i]; + } +} +void c() { + int[11] arr; + int N = 10; + for (int i = 1; i < N; i++) { + arr[i] = arr[N-i+1]; + } +} +void d() { + int[11] arr; + int N = 10; + for (int i = 1; i < N; i++) { + arr[N-i+1] = arr[i]; + } +} +void e() { + int[6] arr; + int N = 5; + for (int i = N; i > 0; i--) { + arr[i] = arr[N-i]; + } +} +void f() { + int[6] arr; + int N = 5; + for (int i = N; i > 0; i--) { + arr[N-i] = arr[i]; + } +} +void g() { + int[11] arr; + int N = 10; + for (int i = N; i > 0; i--) { + arr[i] = arr[N-i+1]; + } +} +void h() { + int[11] arr; + int N = 10; + for (int i = N; i > 0; i--) { + arr[N-i+1] = arr[i]; + } +} +void main(){ + a(); + b(); + c(); + d(); + e(); + f(); + g(); + h(); +} +*/ +TEST(DependencyAnalysis, Crossing) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %14 "e(" + OpName %16 "f(" + OpName %18 "g(" + OpName %20 "h(" + OpName %24 "N" + OpName %26 "i" + OpName %41 "arr" + OpName %51 "N" + OpName %52 "i" + OpName %61 "arr" + OpName %71 "N" + OpName %73 "i" + OpName %85 "arr" + OpName %96 "N" + OpName %97 "i" + OpName %106 "arr" + OpName %117 "N" + OpName %118 "i" + OpName %128 "arr" + OpName %138 "N" + OpName %139 "i" + OpName %148 "arr" + OpName %158 "N" + OpName %159 "i" + OpName %168 "arr" + OpName %179 "N" + OpName %180 "i" + OpName %189 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %22 = OpTypeInt 32 1 + %23 = OpTypePointer Function %22 + %25 = OpConstant %22 5 + %27 = OpConstant %22 1 + %35 = OpTypeBool + %37 = OpTypeInt 32 0 + %38 = OpConstant %37 6 + %39 = OpTypeArray %22 %38 + %40 = OpTypePointer Function %39 + %72 = OpConstant %22 10 + %82 = OpConstant %37 11 + %83 = OpTypeArray %22 %82 + %84 = OpTypePointer Function %83 + %126 = OpConstant %22 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %200 = OpFunctionCall %2 %6 + %201 = OpFunctionCall %2 %8 + %202 = OpFunctionCall %2 %10 + %203 = OpFunctionCall %2 %12 + %204 = OpFunctionCall %2 %14 + %205 = OpFunctionCall %2 %16 + %206 = OpFunctionCall %2 %18 + %207 = OpFunctionCall %2 %20 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %24 = OpVariable %23 Function + %26 = OpVariable %23 Function + %41 = OpVariable %40 Function + OpStore %24 %25 + OpStore %26 %27 + OpBranch %28 + %28 = OpLabel + %208 = OpPhi %22 %27 %7 %50 %31 + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %36 = OpSLessThan %35 %208 %25 + OpBranchConditional %36 %29 %30 + %29 = OpLabel + %45 = OpISub %22 %25 %208 + %46 = OpAccessChain %23 %41 %45 + %47 = OpLoad %22 %46 + %48 = OpAccessChain %23 %41 %208 + OpStore %48 %47 + OpBranch %31 + %31 = OpLabel + %50 = OpIAdd %22 %208 %27 + OpStore %26 %50 + OpBranch %28 + %30 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %51 = OpVariable %23 Function + %52 = OpVariable %23 Function + %61 = OpVariable %40 Function + OpStore %51 %25 + OpStore %52 %27 + OpBranch %53 + %53 = OpLabel + %209 = OpPhi %22 %27 %9 %70 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %60 = OpSLessThan %35 %209 %25 + OpBranchConditional %60 %54 %55 + %54 = OpLabel + %64 = OpISub %22 %25 %209 + %66 = OpAccessChain %23 %61 %209 + %67 = OpLoad %22 %66 + %68 = OpAccessChain %23 %61 %64 + OpStore %68 %67 + OpBranch %56 + %56 = OpLabel + %70 = OpIAdd %22 %209 %27 + OpStore %52 %70 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %71 = OpVariable %23 Function + %73 = OpVariable %23 Function + %85 = OpVariable %84 Function + OpStore %71 %72 + OpStore %73 %27 + OpBranch %74 + %74 = OpLabel + %210 = OpPhi %22 %27 %11 %95 %77 + OpLoopMerge %76 %77 None + OpBranch %78 + %78 = OpLabel + %81 = OpSLessThan %35 %210 %72 + OpBranchConditional %81 %75 %76 + %75 = OpLabel + %89 = OpISub %22 %72 %210 + %90 = OpIAdd %22 %89 %27 + %91 = OpAccessChain %23 %85 %90 + %92 = OpLoad %22 %91 + %93 = OpAccessChain %23 %85 %210 + OpStore %93 %92 + OpBranch %77 + %77 = OpLabel + %95 = OpIAdd %22 %210 %27 + OpStore %73 %95 + OpBranch %74 + %76 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %96 = OpVariable %23 Function + %97 = OpVariable %23 Function + %106 = OpVariable %84 Function + OpStore %96 %72 + OpStore %97 %27 + OpBranch %98 + %98 = OpLabel + %211 = OpPhi %22 %27 %13 %116 %101 + OpLoopMerge %100 %101 None + OpBranch %102 + %102 = OpLabel + %105 = OpSLessThan %35 %211 %72 + OpBranchConditional %105 %99 %100 + %99 = OpLabel + %109 = OpISub %22 %72 %211 + %110 = OpIAdd %22 %109 %27 + %112 = OpAccessChain %23 %106 %211 + %113 = OpLoad %22 %112 + %114 = OpAccessChain %23 %106 %110 + OpStore %114 %113 + OpBranch %101 + %101 = OpLabel + %116 = OpIAdd %22 %211 %27 + OpStore %97 %116 + OpBranch %98 + %100 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %2 None %3 + %15 = OpLabel + %117 = OpVariable %23 Function + %118 = OpVariable %23 Function + %128 = OpVariable %40 Function + OpStore %117 %25 + OpStore %118 %25 + OpBranch %120 + %120 = OpLabel + %212 = OpPhi %22 %25 %15 %137 %123 + OpLoopMerge %122 %123 None + OpBranch %124 + %124 = OpLabel + %127 = OpSGreaterThan %35 %212 %126 + OpBranchConditional %127 %121 %122 + %121 = OpLabel + %132 = OpISub %22 %25 %212 + %133 = OpAccessChain %23 %128 %132 + %134 = OpLoad %22 %133 + %135 = OpAccessChain %23 %128 %212 + OpStore %135 %134 + OpBranch %123 + %123 = OpLabel + %137 = OpISub %22 %212 %27 + OpStore %118 %137 + OpBranch %120 + %122 = OpLabel + OpReturn + OpFunctionEnd + %16 = OpFunction %2 None %3 + %17 = OpLabel + %138 = OpVariable %23 Function + %139 = OpVariable %23 Function + %148 = OpVariable %40 Function + OpStore %138 %25 + OpStore %139 %25 + OpBranch %141 + %141 = OpLabel + %213 = OpPhi %22 %25 %17 %157 %144 + OpLoopMerge %143 %144 None + OpBranch %145 + %145 = OpLabel + %147 = OpSGreaterThan %35 %213 %126 + OpBranchConditional %147 %142 %143 + %142 = OpLabel + %151 = OpISub %22 %25 %213 + %153 = OpAccessChain %23 %148 %213 + %154 = OpLoad %22 %153 + %155 = OpAccessChain %23 %148 %151 + OpStore %155 %154 + OpBranch %144 + %144 = OpLabel + %157 = OpISub %22 %213 %27 + OpStore %139 %157 + OpBranch %141 + %143 = OpLabel + OpReturn + OpFunctionEnd + %18 = OpFunction %2 None %3 + %19 = OpLabel + %158 = OpVariable %23 Function + %159 = OpVariable %23 Function + %168 = OpVariable %84 Function + OpStore %158 %72 + OpStore %159 %72 + OpBranch %161 + %161 = OpLabel + %214 = OpPhi %22 %72 %19 %178 %164 + OpLoopMerge %163 %164 None + OpBranch %165 + %165 = OpLabel + %167 = OpSGreaterThan %35 %214 %126 + OpBranchConditional %167 %162 %163 + %162 = OpLabel + %172 = OpISub %22 %72 %214 + %173 = OpIAdd %22 %172 %27 + %174 = OpAccessChain %23 %168 %173 + %175 = OpLoad %22 %174 + %176 = OpAccessChain %23 %168 %214 + OpStore %176 %175 + OpBranch %164 + %164 = OpLabel + %178 = OpISub %22 %214 %27 + OpStore %159 %178 + OpBranch %161 + %163 = OpLabel + OpReturn + OpFunctionEnd + %20 = OpFunction %2 None %3 + %21 = OpLabel + %179 = OpVariable %23 Function + %180 = OpVariable %23 Function + %189 = OpVariable %84 Function + OpStore %179 %72 + OpStore %180 %72 + OpBranch %182 + %182 = OpLabel + %215 = OpPhi %22 %72 %21 %199 %185 + OpLoopMerge %184 %185 None + OpBranch %186 + %186 = OpLabel + %188 = OpSGreaterThan %35 %215 %126 + OpBranchConditional %188 %183 %184 + %183 = OpLabel + %192 = OpISub %22 %72 %215 + %193 = OpIAdd %22 %192 %27 + %195 = OpAccessChain %23 %189 %215 + %196 = OpLoad %22 %195 + %197 = OpAccessChain %23 %189 %193 + OpStore %197 %196 + OpBranch %185 + %185 = OpLabel + %199 = OpISub %22 %215 %27 + OpStore %180 %199 + OpBranch %182 + %184 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + // First two tests can be split into two loops. + // Tests even crossing subscripts from low to high indexes. + // 47 -> 48 + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(47), + store, &distance_vector)); + } + + // Tests even crossing subscripts from high to low indexes. + // 67 -> 68 + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 54)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(67), + store, &distance_vector)); + } + + // Next two tests can have an end peeled, then be split. + // Tests uneven crossing subscripts from low to high indexes. + // 92 -> 93 + { + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 75)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(92), + store, &distance_vector)); + } + + // Tests uneven crossing subscripts from high to low indexes. + // 113 -> 114 + { + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 99)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(113), + store, &distance_vector)); + } + + // First two tests can be split into two loops. + // Tests even crossing subscripts from low to high indexes. + // 134 -> 135 + { + const Function* f = spvtest::GetFunction(module, 14); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 121)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(134), + store, &distance_vector)); + } + + // Tests even crossing subscripts from high to low indexes. + // 154 -> 155 + { + const Function* f = spvtest::GetFunction(module, 16); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 142)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(154), + store, &distance_vector)); + } + + // Next two tests can have an end peeled, then be split. + // Tests uneven crossing subscripts from low to high indexes. + // 175 -> 176 + { + const Function* f = spvtest::GetFunction(module, 18); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 162)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(175), + store, &distance_vector)); + } + + // Tests uneven crossing subscripts from high to low indexes. + // 196 -> 197 + { + const Function* f = spvtest::GetFunction(module, 20); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 183)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(196), + store, &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + int[10] arr; + for (int i = 0; i < 10; i++) { + arr[0] = arr[i]; // peel first + arr[i] = arr[0]; // peel first + arr[9] = arr[i]; // peel last + arr[i] = arr[9]; // peel last + } +} +void b() { + int[11] arr; + for (int i = 0; i <= 10; i++) { + arr[0] = arr[i]; // peel first + arr[i] = arr[0]; // peel first + arr[10] = arr[i]; // peel last + arr[i] = arr[10]; // peel last + + } +} +void c() { + int[11] arr; + for (int i = 10; i > 0; i--) { + arr[10] = arr[i]; // peel first + arr[i] = arr[10]; // peel first + arr[1] = arr[i]; // peel last + arr[i] = arr[1]; // peel last + + } +} +void d() { + int[11] arr; + for (int i = 10; i >= 0; i--) { + arr[10] = arr[i]; // peel first + arr[i] = arr[10]; // peel first + arr[0] = arr[i]; // peel last + arr[i] = arr[0]; // peel last + + } +} +void main(){ + a(); + b(); + c(); + d(); +} +*/ +TEST(DependencyAnalysis, WeakZeroSIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %16 "i" + OpName %31 "arr" + OpName %52 "i" + OpName %63 "arr" + OpName %82 "i" + OpName %90 "arr" + OpName %109 "i" + OpName %117 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpConstant %14 0 + %24 = OpConstant %14 10 + %25 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %14 %28 + %30 = OpTypePointer Function %29 + %40 = OpConstant %14 9 + %50 = OpConstant %14 1 + %60 = OpConstant %27 11 + %61 = OpTypeArray %14 %60 + %62 = OpTypePointer Function %61 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %136 = OpFunctionCall %2 %6 + %137 = OpFunctionCall %2 %8 + %138 = OpFunctionCall %2 %10 + %139 = OpFunctionCall %2 %12 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %16 = OpVariable %15 Function + %31 = OpVariable %30 Function + OpStore %16 %17 + OpBranch %18 + %18 = OpLabel + %140 = OpPhi %14 %17 %7 %51 %21 + OpLoopMerge %20 %21 None + OpBranch %22 + %22 = OpLabel + %26 = OpSLessThan %25 %140 %24 + OpBranchConditional %26 %19 %20 + %19 = OpLabel + %33 = OpAccessChain %15 %31 %140 + %34 = OpLoad %14 %33 + %35 = OpAccessChain %15 %31 %17 + OpStore %35 %34 + %37 = OpAccessChain %15 %31 %17 + %38 = OpLoad %14 %37 + %39 = OpAccessChain %15 %31 %140 + OpStore %39 %38 + %42 = OpAccessChain %15 %31 %140 + %43 = OpLoad %14 %42 + %44 = OpAccessChain %15 %31 %40 + OpStore %44 %43 + %46 = OpAccessChain %15 %31 %40 + %47 = OpLoad %14 %46 + %48 = OpAccessChain %15 %31 %140 + OpStore %48 %47 + OpBranch %21 + %21 = OpLabel + %51 = OpIAdd %14 %140 %50 + OpStore %16 %51 + OpBranch %18 + %20 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %52 = OpVariable %15 Function + %63 = OpVariable %62 Function + OpStore %52 %17 + OpBranch %53 + %53 = OpLabel + %141 = OpPhi %14 %17 %9 %81 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %59 = OpSLessThanEqual %25 %141 %24 + OpBranchConditional %59 %54 %55 + %54 = OpLabel + %65 = OpAccessChain %15 %63 %141 + %66 = OpLoad %14 %65 + %67 = OpAccessChain %15 %63 %17 + OpStore %67 %66 + %69 = OpAccessChain %15 %63 %17 + %70 = OpLoad %14 %69 + %71 = OpAccessChain %15 %63 %141 + OpStore %71 %70 + %73 = OpAccessChain %15 %63 %141 + %74 = OpLoad %14 %73 + %75 = OpAccessChain %15 %63 %24 + OpStore %75 %74 + %77 = OpAccessChain %15 %63 %24 + %78 = OpLoad %14 %77 + %79 = OpAccessChain %15 %63 %141 + OpStore %79 %78 + OpBranch %56 + %56 = OpLabel + %81 = OpIAdd %14 %141 %50 + OpStore %52 %81 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %82 = OpVariable %15 Function + %90 = OpVariable %62 Function + OpStore %82 %24 + OpBranch %83 + %83 = OpLabel + %142 = OpPhi %14 %24 %11 %108 %86 + OpLoopMerge %85 %86 None + OpBranch %87 + %87 = OpLabel + %89 = OpSGreaterThan %25 %142 %17 + OpBranchConditional %89 %84 %85 + %84 = OpLabel + %92 = OpAccessChain %15 %90 %142 + %93 = OpLoad %14 %92 + %94 = OpAccessChain %15 %90 %24 + OpStore %94 %93 + %96 = OpAccessChain %15 %90 %24 + %97 = OpLoad %14 %96 + %98 = OpAccessChain %15 %90 %142 + OpStore %98 %97 + %100 = OpAccessChain %15 %90 %142 + %101 = OpLoad %14 %100 + %102 = OpAccessChain %15 %90 %50 + OpStore %102 %101 + %104 = OpAccessChain %15 %90 %50 + %105 = OpLoad %14 %104 + %106 = OpAccessChain %15 %90 %142 + OpStore %106 %105 + OpBranch %86 + %86 = OpLabel + %108 = OpISub %14 %142 %50 + OpStore %82 %108 + OpBranch %83 + %85 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %109 = OpVariable %15 Function + %117 = OpVariable %62 Function + OpStore %109 %24 + OpBranch %110 + %110 = OpLabel + %143 = OpPhi %14 %24 %13 %135 %113 + OpLoopMerge %112 %113 None + OpBranch %114 + %114 = OpLabel + %116 = OpSGreaterThanEqual %25 %143 %17 + OpBranchConditional %116 %111 %112 + %111 = OpLabel + %119 = OpAccessChain %15 %117 %143 + %120 = OpLoad %14 %119 + %121 = OpAccessChain %15 %117 %24 + OpStore %121 %120 + %123 = OpAccessChain %15 %117 %24 + %124 = OpLoad %14 %123 + %125 = OpAccessChain %15 %117 %143 + OpStore %125 %124 + %127 = OpAccessChain %15 %117 %143 + %128 = OpLoad %14 %127 + %129 = OpAccessChain %15 %117 %17 + OpStore %129 %128 + %131 = OpAccessChain %15 %117 %17 + %132 = OpLoad %14 %131 + %133 = OpAccessChain %15 %117 %143 + OpStore %133 %132 + OpBranch %113 + %113 = OpLabel + %135 = OpISub %14 %143 %50 + OpStore %109 %135 + OpBranch %110 + %112 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 19)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 34 -> 35 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(34), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 38 -> 39 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(38), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 43 -> 44 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(43), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 47 -> 48 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(47), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } + // For the loop in function b + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 54)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 66 -> 67 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(66), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 70 -> 71 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(70), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 74 -> 75 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(74), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 78 -> 79 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(78), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } + // For the loop in function c + { + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 84)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 93 -> 94 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(93), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 97 -> 98 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(97), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 101 -> 102 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(101), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 105 -> 106 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(105), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } + // For the loop in function d + { + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 111)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 120 -> 121 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(120), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 124 -> 125 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(124), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 128 -> 129 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(128), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 132 -> 133 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(132), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + int[10][10] arr; + for (int i = 0; i < 10; i++) { + arr[i][i] = arr[i][i]; + arr[0][i] = arr[1][i]; + arr[1][i] = arr[0][i]; + arr[i][0] = arr[i][1]; + arr[i][1] = arr[i][0]; + arr[0][1] = arr[1][0]; + } +} +*/ +TEST(DependencyAnalysis, MultipleSubscriptZIVSIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %24 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypeArray %21 %20 + %23 = OpTypePointer Function %22 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %24 = OpVariable %23 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %58 = OpPhi %6 %9 %5 %57 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %58 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %29 = OpAccessChain %7 %24 %58 %58 + %30 = OpLoad %6 %29 + %31 = OpAccessChain %7 %24 %58 %58 + OpStore %31 %30 + %35 = OpAccessChain %7 %24 %33 %58 + %36 = OpLoad %6 %35 + %37 = OpAccessChain %7 %24 %9 %58 + OpStore %37 %36 + %40 = OpAccessChain %7 %24 %9 %58 + %41 = OpLoad %6 %40 + %42 = OpAccessChain %7 %24 %33 %58 + OpStore %42 %41 + %45 = OpAccessChain %7 %24 %58 %33 + %46 = OpLoad %6 %45 + %47 = OpAccessChain %7 %24 %58 %9 + OpStore %47 %46 + %50 = OpAccessChain %7 %24 %58 %9 + %51 = OpLoad %6 %50 + %52 = OpAccessChain %7 %24 %58 %33 + OpStore %52 %51 + %53 = OpAccessChain %7 %24 %33 %9 + %54 = OpLoad %6 %53 + %55 = OpAccessChain %7 %24 %9 %33 + OpStore %55 %54 + OpBranch %13 + %13 = OpLabel + %57 = OpIAdd %6 %58 %33 + OpStore %8 %57 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[6]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(store[i]); + } + + // 30 -> 31 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(30), + store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // 36 -> 37 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(36), + store[1], &distance_vector)); + } + + // 41 -> 42 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(41), + store[2], &distance_vector)); + } + + // 46 -> 47 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(46), + store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // 51 -> 52 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(51), + store[4], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // 54 -> 55 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(54), + store[5], &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a(){ + int[10] arr; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + arr[j] = arr[j]; + } + } +} +void b(){ + int[10] arr; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + arr[i] = arr[i]; + } + } +} +void main() { + a(); + b(); +} +*/ +TEST(DependencyAnalysis, IrrelevantSubscripts) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "i" + OpName %23 "j" + OpName %35 "arr" + OpName %46 "i" + OpName %54 "j" + OpName %62 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpConstant %10 0 + %20 = OpConstant %10 10 + %21 = OpTypeBool + %31 = OpTypeInt 32 0 + %32 = OpConstant %31 10 + %33 = OpTypeArray %10 %32 + %34 = OpTypePointer Function %33 + %42 = OpConstant %10 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %72 = OpFunctionCall %2 %6 + %73 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %23 = OpVariable %11 Function + %35 = OpVariable %34 Function + OpStore %12 %13 + OpBranch %14 + %14 = OpLabel + %74 = OpPhi %10 %13 %7 %45 %17 + OpLoopMerge %16 %17 None + OpBranch %18 + %18 = OpLabel + %22 = OpSLessThan %21 %74 %20 + OpBranchConditional %22 %15 %16 + %15 = OpLabel + OpStore %23 %13 + OpBranch %24 + %24 = OpLabel + %75 = OpPhi %10 %13 %15 %43 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %30 = OpSLessThan %21 %75 %20 + OpBranchConditional %30 %25 %26 + %25 = OpLabel + %38 = OpAccessChain %11 %35 %75 + %39 = OpLoad %10 %38 + %40 = OpAccessChain %11 %35 %75 + OpStore %40 %39 + OpBranch %27 + %27 = OpLabel + %43 = OpIAdd %10 %75 %42 + OpStore %23 %43 + OpBranch %24 + %26 = OpLabel + OpBranch %17 + %17 = OpLabel + %45 = OpIAdd %10 %74 %42 + OpStore %12 %45 + OpBranch %14 + %16 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %46 = OpVariable %11 Function + %54 = OpVariable %11 Function + %62 = OpVariable %34 Function + OpStore %46 %13 + OpBranch %47 + %47 = OpLabel + %77 = OpPhi %10 %13 %9 %71 %50 + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %53 = OpSLessThan %21 %77 %20 + OpBranchConditional %53 %48 %49 + %48 = OpLabel + OpStore %54 %13 + OpBranch %55 + %55 = OpLabel + %78 = OpPhi %10 %13 %48 %69 %58 + OpLoopMerge %57 %58 None + OpBranch %59 + %59 = OpLabel + %61 = OpSLessThan %21 %78 %20 + OpBranchConditional %61 %56 %57 + %56 = OpLabel + %65 = OpAccessChain %11 %62 %77 + %66 = OpLoad %10 %65 + %67 = OpAccessChain %11 %62 %77 + OpStore %67 %66 + OpBranch %58 + %58 = OpLabel + %69 = OpIAdd %10 %78 %42 + OpStore %54 %69 + OpBranch %55 + %57 = OpLabel + OpBranch %50 + %50 = OpLabel + %71 = OpIAdd %10 %77 %42 + OpStore %46 %71 + OpBranch %47 + %49 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loops{&ld.GetLoopByIndex(1), + &ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 25)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 1; ++i) { + EXPECT_TRUE(store[i]); + } + + // 39 -> 40 + { + DistanceVector distance_vector{loops.size()}; + analysis.SetDebugStream(std::cout); + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(39), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::IRRELEVANT); + EXPECT_EQ(distance_vector.GetEntries()[1].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[1].distance, 0); + } + } + + // For the loop in function b + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loops{&ld.GetLoopByIndex(1), + &ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 56)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 1; ++i) { + EXPECT_TRUE(store[i]); + } + + // 66 -> 67 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(66), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + EXPECT_EQ(distance_vector.GetEntries()[1].dependence_information, + DistanceEntry::DependenceInformation::IRRELEVANT); + } + } +} + +void CheckDependenceAndDirection(const Instruction* source, + const Instruction* destination, + bool expected_dependence, + DistanceVector expected_distance, + LoopDependenceAnalysis* analysis) { + DistanceVector dv_entry(2); + EXPECT_EQ(expected_dependence, + analysis->GetDependence(source, destination, &dv_entry)); + EXPECT_EQ(expected_distance, dv_entry); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 c; +void main(){ + int[10] arr; + int a = 2; + int b = 3; + int N = int(c.x); + for (int i = 0; i < 10; i++) { + for (int j = 2; j < 10; j++) { + arr[i] = arr[j]; // 0 + arr[j] = arr[i]; // 1 + arr[j-2] = arr[i+3]; // 2 + arr[j-a] = arr[i+b]; // 3 + arr[2*i] = arr[4*j+3]; // 4, independent + arr[2*i] = arr[4*j]; // 5 + arr[i+j] = arr[i+j]; // 6 + arr[10*i+j] = arr[10*i+j]; // 7 + arr[10*i+10*j] = arr[10*i+10*j+3]; // 8, independent + arr[10*i+10*j] = arr[10*i+N*j+3]; // 9, bail out because of N coefficient + arr[10*i+10*j] = arr[10*i+10*j+N]; // 10, bail out because of N constant + // term + arr[10*i+N*j] = arr[10*i+10*j+3]; // 11, bail out because of N coefficient + arr[10*i+10*j+N] = arr[10*i+10*j]; // 12, bail out because of N constant + // term + arr[10*i] = arr[5*j]; // 13, independent + arr[5*i] = arr[10*j]; // 14, independent + arr[9*i] = arr[3*j]; // 15, independent + arr[3*i] = arr[9*j]; // 16, independent + } + } +} +*/ +TEST(DependencyAnalysis, MIV) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %16 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "a" + OpName %10 "b" + OpName %12 "N" + OpName %16 "c" + OpName %23 "i" + OpName %34 "j" + OpName %45 "arr" + OpDecorate %16 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %11 = OpConstant %6 3 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Input %14 + %16 = OpVariable %15 Input + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %19 = OpTypePointer Input %13 + %24 = OpConstant %6 0 + %31 = OpConstant %6 10 + %32 = OpTypeBool + %42 = OpConstant %17 10 + %43 = OpTypeArray %6 %42 + %44 = OpTypePointer Function %43 + %74 = OpConstant %6 4 + %184 = OpConstant %6 5 + %197 = OpConstant %6 9 + %213 = OpConstant %6 1 + %218 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %12 = OpVariable %7 Function + %23 = OpVariable %7 Function + %34 = OpVariable %7 Function + %45 = OpVariable %44 Function + OpStore %8 %9 + OpStore %10 %11 + %20 = OpAccessChain %19 %16 %18 + %21 = OpLoad %13 %20 + %22 = OpConvertFToS %6 %21 + OpStore %12 %22 + OpStore %23 %24 + OpBranch %25 + %25 = OpLabel + %217 = OpPhi %6 %24 %5 %216 %28 + %219 = OpPhi %6 %218 %5 %220 %28 + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %33 = OpSLessThan %32 %217 %31 + OpBranchConditional %33 %26 %27 + %26 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %220 = OpPhi %6 %9 %26 %214 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %32 %220 %31 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %48 = OpAccessChain %7 %45 %220 + %49 = OpLoad %6 %48 + %50 = OpAccessChain %7 %45 %217 + OpStore %50 %49 + %53 = OpAccessChain %7 %45 %217 + %54 = OpLoad %6 %53 + %55 = OpAccessChain %7 %45 %220 + OpStore %55 %54 + %57 = OpISub %6 %220 %9 + %59 = OpIAdd %6 %217 %11 + %60 = OpAccessChain %7 %45 %59 + %61 = OpLoad %6 %60 + %62 = OpAccessChain %7 %45 %57 + OpStore %62 %61 + %65 = OpISub %6 %220 %9 + %68 = OpIAdd %6 %217 %11 + %69 = OpAccessChain %7 %45 %68 + %70 = OpLoad %6 %69 + %71 = OpAccessChain %7 %45 %65 + OpStore %71 %70 + %73 = OpIMul %6 %9 %217 + %76 = OpIMul %6 %74 %220 + %77 = OpIAdd %6 %76 %11 + %78 = OpAccessChain %7 %45 %77 + %79 = OpLoad %6 %78 + %80 = OpAccessChain %7 %45 %73 + OpStore %80 %79 + %82 = OpIMul %6 %9 %217 + %84 = OpIMul %6 %74 %220 + %85 = OpAccessChain %7 %45 %84 + %86 = OpLoad %6 %85 + %87 = OpAccessChain %7 %45 %82 + OpStore %87 %86 + %90 = OpIAdd %6 %217 %220 + %93 = OpIAdd %6 %217 %220 + %94 = OpAccessChain %7 %45 %93 + %95 = OpLoad %6 %94 + %96 = OpAccessChain %7 %45 %90 + OpStore %96 %95 + %98 = OpIMul %6 %31 %217 + %100 = OpIAdd %6 %98 %220 + %102 = OpIMul %6 %31 %217 + %104 = OpIAdd %6 %102 %220 + %105 = OpAccessChain %7 %45 %104 + %106 = OpLoad %6 %105 + %107 = OpAccessChain %7 %45 %100 + OpStore %107 %106 + %109 = OpIMul %6 %31 %217 + %111 = OpIMul %6 %31 %220 + %112 = OpIAdd %6 %109 %111 + %114 = OpIMul %6 %31 %217 + %116 = OpIMul %6 %31 %220 + %117 = OpIAdd %6 %114 %116 + %118 = OpIAdd %6 %117 %11 + %119 = OpAccessChain %7 %45 %118 + %120 = OpLoad %6 %119 + %121 = OpAccessChain %7 %45 %112 + OpStore %121 %120 + %123 = OpIMul %6 %31 %217 + %125 = OpIMul %6 %31 %220 + %126 = OpIAdd %6 %123 %125 + %128 = OpIMul %6 %31 %217 + %131 = OpIMul %6 %22 %220 + %132 = OpIAdd %6 %128 %131 + %133 = OpIAdd %6 %132 %11 + %134 = OpAccessChain %7 %45 %133 + %135 = OpLoad %6 %134 + %136 = OpAccessChain %7 %45 %126 + OpStore %136 %135 + %138 = OpIMul %6 %31 %217 + %140 = OpIMul %6 %31 %220 + %141 = OpIAdd %6 %138 %140 + %143 = OpIMul %6 %31 %217 + %145 = OpIMul %6 %31 %220 + %146 = OpIAdd %6 %143 %145 + %148 = OpIAdd %6 %146 %22 + %149 = OpAccessChain %7 %45 %148 + %150 = OpLoad %6 %149 + %151 = OpAccessChain %7 %45 %141 + OpStore %151 %150 + %153 = OpIMul %6 %31 %217 + %156 = OpIMul %6 %22 %220 + %157 = OpIAdd %6 %153 %156 + %159 = OpIMul %6 %31 %217 + %161 = OpIMul %6 %31 %220 + %162 = OpIAdd %6 %159 %161 + %163 = OpIAdd %6 %162 %11 + %164 = OpAccessChain %7 %45 %163 + %165 = OpLoad %6 %164 + %166 = OpAccessChain %7 %45 %157 + OpStore %166 %165 + %168 = OpIMul %6 %31 %217 + %170 = OpIMul %6 %31 %220 + %171 = OpIAdd %6 %168 %170 + %173 = OpIAdd %6 %171 %22 + %175 = OpIMul %6 %31 %217 + %177 = OpIMul %6 %31 %220 + %178 = OpIAdd %6 %175 %177 + %179 = OpAccessChain %7 %45 %178 + %180 = OpLoad %6 %179 + %181 = OpAccessChain %7 %45 %173 + OpStore %181 %180 + %183 = OpIMul %6 %31 %217 + %186 = OpIMul %6 %184 %220 + %187 = OpAccessChain %7 %45 %186 + %188 = OpLoad %6 %187 + %189 = OpAccessChain %7 %45 %183 + OpStore %189 %188 + %191 = OpIMul %6 %184 %217 + %193 = OpIMul %6 %31 %220 + %194 = OpAccessChain %7 %45 %193 + %195 = OpLoad %6 %194 + %196 = OpAccessChain %7 %45 %191 + OpStore %196 %195 + %199 = OpIMul %6 %197 %217 + %201 = OpIMul %6 %11 %220 + %202 = OpAccessChain %7 %45 %201 + %203 = OpLoad %6 %202 + %204 = OpAccessChain %7 %45 %199 + OpStore %204 %203 + %206 = OpIMul %6 %11 %217 + %208 = OpIMul %6 %197 %220 + %209 = OpAccessChain %7 %45 %208 + %210 = OpLoad %6 %209 + %211 = OpAccessChain %7 %45 %206 + OpStore %211 %210 + OpBranch %38 + %38 = OpLabel + %214 = OpIAdd %6 %220 %213 + OpStore %34 %214 + OpBranch %35 + %37 = OpLabel + OpBranch %28 + %28 = OpLabel + %216 = OpIAdd %6 %217 %213 + OpStore %23 %216 + OpBranch %25 + %27 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loops{&ld.GetLoopByIndex(0), &ld.GetLoopByIndex(1)}; + + LoopDependenceAnalysis analysis{context.get(), loops}; + + const int instructions_expected = 17; + const Instruction* store[instructions_expected]; + const Instruction* load[instructions_expected]; + int stores_found = 0; + int loads_found = 0; + + int block_id = 36; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load[loads_found] = &inst; + ++loads_found; + } + } + + EXPECT_EQ(instructions_expected, stores_found); + EXPECT_EQ(instructions_expected, loads_found); + + auto directions_all = DistanceEntry(DistanceEntry::Directions::ALL); + auto directions_none = DistanceEntry(DistanceEntry::Directions::NONE); + + auto dependent = DistanceVector({directions_all, directions_all}); + auto independent = DistanceVector({directions_none, directions_none}); + + CheckDependenceAndDirection(load[0], store[0], false, dependent, &analysis); + CheckDependenceAndDirection(load[1], store[1], false, dependent, &analysis); + CheckDependenceAndDirection(load[2], store[2], false, dependent, &analysis); + CheckDependenceAndDirection(load[3], store[3], false, dependent, &analysis); + CheckDependenceAndDirection(load[4], store[4], true, independent, &analysis); + CheckDependenceAndDirection(load[5], store[5], false, dependent, &analysis); + CheckDependenceAndDirection(load[6], store[6], false, dependent, &analysis); + CheckDependenceAndDirection(load[7], store[7], false, dependent, &analysis); + CheckDependenceAndDirection(load[8], store[8], true, independent, &analysis); + CheckDependenceAndDirection(load[9], store[9], false, dependent, &analysis); + CheckDependenceAndDirection(load[10], store[10], false, dependent, &analysis); + CheckDependenceAndDirection(load[11], store[11], false, dependent, &analysis); + CheckDependenceAndDirection(load[12], store[12], false, dependent, &analysis); + CheckDependenceAndDirection(load[13], store[13], true, independent, + &analysis); + CheckDependenceAndDirection(load[14], store[14], true, independent, + &analysis); + CheckDependenceAndDirection(load[15], store[15], true, independent, + &analysis); + CheckDependenceAndDirection(load[16], store[16], true, independent, + &analysis); +} + +void PartitionSubscripts(const Instruction* instruction_0, + const Instruction* instruction_1, + LoopDependenceAnalysis* analysis, + std::vector> expected_ids) { + auto subscripts_0 = analysis->GetSubscripts(instruction_0); + auto subscripts_1 = analysis->GetSubscripts(instruction_1); + + std::vector>> + expected_partition{}; + + for (const auto& partition : expected_ids) { + expected_partition.push_back( + std::set>{}); + for (auto id : partition) { + expected_partition.back().insert({subscripts_0[id], subscripts_1[id]}); + } + } + + EXPECT_EQ(expected_partition, + analysis->PartitionSubscripts(subscripts_0, subscripts_1)); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + int[10][10][10][10] arr; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + for (int l = 0; l < 10; l++) { + arr[i][j][k][l] = arr[i][j][k][l]; // 0, all independent + arr[i][j][k][l] = arr[i][j][l][0]; // 1, last 2 coupled + arr[i][j][k][l] = arr[j][i][k][l]; // 2, first 2 coupled + arr[i][j][k][l] = arr[l][j][k][i]; // 3, first & last coupled + arr[i][j][k][l] = arr[i][k][j][l]; // 4, middle 2 coupled + arr[i+j][j][k][l] = arr[i][j][k][l]; // 5, first 2 coupled + arr[i+j+k][j][k][l] = arr[i][j][k][l]; // 6, first 3 coupled + arr[i+j+k+l][j][k][l] = arr[i][j][k][l]; // 7, all 4 coupled + arr[i][j][k][l] = arr[i][l][j][k]; // 8, last 3 coupled + arr[i][j-k][k][l] = arr[i][j][l][k]; // 9, last 3 coupled + arr[i][j][k][l] = arr[l][i][j][k]; // 10, all 4 coupled + arr[i][j][k][l] = arr[j][i][l][k]; // 11, 2 coupled partitions (i,j) & +(l&k) + arr[i][j][k][l] = arr[k][l][i][j]; // 12, 2 coupled partitions (i,k) & +(j&l) + } + } + } + } +} +*/ +TEST(DependencyAnalysis, SubscriptPartitioning) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %27 "k" + OpName %35 "l" + OpName %50 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %43 = OpTypeInt 32 0 + %44 = OpConstant %43 10 + %45 = OpTypeArray %6 %44 + %46 = OpTypeArray %45 %44 + %47 = OpTypeArray %46 %44 + %48 = OpTypeArray %47 %44 + %49 = OpTypePointer Function %48 + %208 = OpConstant %6 1 + %217 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %27 = OpVariable %7 Function + %35 = OpVariable %7 Function + %50 = OpVariable %49 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %216 = OpPhi %6 %9 %5 %215 %13 + %218 = OpPhi %6 %217 %5 %221 %13 + %219 = OpPhi %6 %217 %5 %222 %13 + %220 = OpPhi %6 %217 %5 %223 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %216 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %221 = OpPhi %6 %9 %11 %213 %23 + %222 = OpPhi %6 %219 %11 %224 %23 + %223 = OpPhi %6 %220 %11 %225 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %221 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + OpStore %27 %9 + OpBranch %28 + %28 = OpLabel + %224 = OpPhi %6 %9 %21 %211 %31 + %225 = OpPhi %6 %223 %21 %226 %31 + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %34 = OpSLessThan %17 %224 %16 + OpBranchConditional %34 %29 %30 + %29 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %226 = OpPhi %6 %9 %29 %209 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %226 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %59 = OpAccessChain %7 %50 %216 %221 %224 %226 + %60 = OpLoad %6 %59 + %61 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %61 %60 + %69 = OpAccessChain %7 %50 %216 %221 %226 %9 + %70 = OpLoad %6 %69 + %71 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %71 %70 + %80 = OpAccessChain %7 %50 %221 %216 %224 %226 + %81 = OpLoad %6 %80 + %82 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %82 %81 + %91 = OpAccessChain %7 %50 %226 %221 %224 %216 + %92 = OpLoad %6 %91 + %93 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %93 %92 + %102 = OpAccessChain %7 %50 %216 %224 %221 %226 + %103 = OpLoad %6 %102 + %104 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %104 %103 + %107 = OpIAdd %6 %216 %221 + %115 = OpAccessChain %7 %50 %216 %221 %224 %226 + %116 = OpLoad %6 %115 + %117 = OpAccessChain %7 %50 %107 %221 %224 %226 + OpStore %117 %116 + %120 = OpIAdd %6 %216 %221 + %122 = OpIAdd %6 %120 %224 + %130 = OpAccessChain %7 %50 %216 %221 %224 %226 + %131 = OpLoad %6 %130 + %132 = OpAccessChain %7 %50 %122 %221 %224 %226 + OpStore %132 %131 + %135 = OpIAdd %6 %216 %221 + %137 = OpIAdd %6 %135 %224 + %139 = OpIAdd %6 %137 %226 + %147 = OpAccessChain %7 %50 %216 %221 %224 %226 + %148 = OpLoad %6 %147 + %149 = OpAccessChain %7 %50 %139 %221 %224 %226 + OpStore %149 %148 + %158 = OpAccessChain %7 %50 %216 %226 %221 %224 + %159 = OpLoad %6 %158 + %160 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %160 %159 + %164 = OpISub %6 %221 %224 + %171 = OpAccessChain %7 %50 %216 %221 %226 %224 + %172 = OpLoad %6 %171 + %173 = OpAccessChain %7 %50 %216 %164 %224 %226 + OpStore %173 %172 + %182 = OpAccessChain %7 %50 %226 %216 %221 %224 + %183 = OpLoad %6 %182 + %184 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %184 %183 + %193 = OpAccessChain %7 %50 %221 %216 %226 %224 + %194 = OpLoad %6 %193 + %195 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %195 %194 + %204 = OpAccessChain %7 %50 %224 %226 %216 %221 + %205 = OpLoad %6 %204 + %206 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %206 %205 + OpBranch %39 + %39 = OpLabel + %209 = OpIAdd %6 %226 %208 + OpStore %35 %209 + OpBranch %36 + %38 = OpLabel + OpBranch %31 + %31 = OpLabel + %211 = OpIAdd %6 %224 %208 + OpStore %27 %211 + OpBranch %28 + %30 = OpLabel + OpBranch %23 + %23 = OpLabel + %213 = OpIAdd %6 %221 %208 + OpStore %19 %213 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %215 = OpIAdd %6 %216 %208 + OpStore %8 %215 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loop_nest{ + &ld.GetLoopByIndex(0), &ld.GetLoopByIndex(1), &ld.GetLoopByIndex(2), + &ld.GetLoopByIndex(3)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + const int instructions_expected = 13; + const Instruction* store[instructions_expected]; + const Instruction* load[instructions_expected]; + int stores_found = 0; + int loads_found = 0; + + int block_id = 37; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load[loads_found] = &inst; + ++loads_found; + } + } + + EXPECT_EQ(instructions_expected, stores_found); + EXPECT_EQ(instructions_expected, loads_found); + + PartitionSubscripts(load[0], store[0], &analysis, {{0}, {1}, {2}, {3}}); + PartitionSubscripts(load[1], store[1], &analysis, {{0}, {1}, {2, 3}}); + PartitionSubscripts(load[2], store[2], &analysis, {{0, 1}, {2}, {3}}); + PartitionSubscripts(load[3], store[3], &analysis, {{0, 3}, {1}, {2}}); + PartitionSubscripts(load[4], store[4], &analysis, {{0}, {1, 2}, {3}}); + PartitionSubscripts(load[5], store[5], &analysis, {{0, 1}, {2}, {3}}); + PartitionSubscripts(load[6], store[6], &analysis, {{0, 1, 2}, {3}}); + PartitionSubscripts(load[7], store[7], &analysis, {{0, 1, 2, 3}}); + PartitionSubscripts(load[8], store[8], &analysis, {{0}, {1, 2, 3}}); + PartitionSubscripts(load[9], store[9], &analysis, {{0}, {1, 2, 3}}); + PartitionSubscripts(load[10], store[10], &analysis, {{0, 1, 2, 3}}); + PartitionSubscripts(load[11], store[11], &analysis, {{0, 1}, {2, 3}}); + PartitionSubscripts(load[12], store[12], &analysis, {{0, 2}, {1, 3}}); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store + +#version 440 core +void a() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + // Dependent, distance vector (1, -1) + arr[i+1][i+j] = arr[i][i+j]; + } + } +} + +void b() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + // Independent + arr[i+1][i+2] = arr[i][i] + 2; + } +} + +void c() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + // Dependence point (1,2) + arr[i][i] = arr[1][i-1] + 2; + } +} + +void d() { + int[10][10][10] arr; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + for (int k = 0; k < 10; ++k) { + // Dependent, distance vector (1,1,-1) + arr[j-i][i+1][j+k] = arr[j-i][i][j+k]; + } + } + } +} + +void e() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + // Independent with GCD after propagation + arr[i][2*j+i] = arr[i][2*j-i+5]; + } + } +} + +void main(){ + a(); + b(); + c(); + d(); + e(); +} +*/ +TEST(DependencyAnalysis, Delta) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %14 "e(" + OpName %18 "i" + OpName %29 "j" + OpName %42 "arr" + OpName %60 "i" + OpName %68 "arr" + OpName %82 "i" + OpName %90 "arr" + OpName %101 "i" + OpName %109 "j" + OpName %117 "k" + OpName %127 "arr" + OpName %152 "i" + OpName %160 "j" + OpName %168 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %16 = OpTypeInt 32 1 + %17 = OpTypePointer Function %16 + %19 = OpConstant %16 0 + %26 = OpConstant %16 10 + %27 = OpTypeBool + %37 = OpTypeInt 32 0 + %38 = OpConstant %37 10 + %39 = OpTypeArray %16 %38 + %40 = OpTypeArray %39 %38 + %41 = OpTypePointer Function %40 + %44 = OpConstant %16 1 + %72 = OpConstant %16 2 + %125 = OpTypeArray %40 %38 + %126 = OpTypePointer Function %125 + %179 = OpConstant %16 5 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %188 = OpFunctionCall %2 %6 + %189 = OpFunctionCall %2 %8 + %190 = OpFunctionCall %2 %10 + %191 = OpFunctionCall %2 %12 + %192 = OpFunctionCall %2 %14 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %18 = OpVariable %17 Function + %29 = OpVariable %17 Function + %42 = OpVariable %41 Function + OpStore %18 %19 + OpBranch %20 + %20 = OpLabel + %193 = OpPhi %16 %19 %7 %59 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %28 = OpSLessThan %27 %193 %26 + OpBranchConditional %28 %21 %22 + %21 = OpLabel + OpStore %29 %19 + OpBranch %30 + %30 = OpLabel + %194 = OpPhi %16 %19 %21 %57 %33 + OpLoopMerge %32 %33 None + OpBranch %34 + %34 = OpLabel + %36 = OpSLessThan %27 %194 %26 + OpBranchConditional %36 %31 %32 + %31 = OpLabel + %45 = OpIAdd %16 %193 %44 + %48 = OpIAdd %16 %193 %194 + %52 = OpIAdd %16 %193 %194 + %53 = OpAccessChain %17 %42 %193 %52 + %54 = OpLoad %16 %53 + %55 = OpAccessChain %17 %42 %45 %48 + OpStore %55 %54 + OpBranch %33 + %33 = OpLabel + %57 = OpIAdd %16 %194 %44 + OpStore %29 %57 + OpBranch %30 + %32 = OpLabel + OpBranch %23 + %23 = OpLabel + %59 = OpIAdd %16 %193 %44 + OpStore %18 %59 + OpBranch %20 + %22 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %60 = OpVariable %17 Function + %68 = OpVariable %41 Function + OpStore %60 %19 + OpBranch %61 + %61 = OpLabel + %196 = OpPhi %16 %19 %9 %81 %64 + OpLoopMerge %63 %64 None + OpBranch %65 + %65 = OpLabel + %67 = OpSLessThan %27 %196 %26 + OpBranchConditional %67 %62 %63 + %62 = OpLabel + %70 = OpIAdd %16 %196 %44 + %73 = OpIAdd %16 %196 %72 + %76 = OpAccessChain %17 %68 %196 %196 + %77 = OpLoad %16 %76 + %78 = OpIAdd %16 %77 %72 + %79 = OpAccessChain %17 %68 %70 %73 + OpStore %79 %78 + OpBranch %64 + %64 = OpLabel + %81 = OpIAdd %16 %196 %44 + OpStore %60 %81 + OpBranch %61 + %63 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %82 = OpVariable %17 Function + %90 = OpVariable %41 Function + OpStore %82 %19 + OpBranch %83 + %83 = OpLabel + %197 = OpPhi %16 %19 %11 %100 %86 + OpLoopMerge %85 %86 None + OpBranch %87 + %87 = OpLabel + %89 = OpSLessThan %27 %197 %26 + OpBranchConditional %89 %84 %85 + %84 = OpLabel + %94 = OpISub %16 %197 %44 + %95 = OpAccessChain %17 %90 %44 %94 + %96 = OpLoad %16 %95 + %97 = OpIAdd %16 %96 %72 + %98 = OpAccessChain %17 %90 %197 %197 + OpStore %98 %97 + OpBranch %86 + %86 = OpLabel + %100 = OpIAdd %16 %197 %44 + OpStore %82 %100 + OpBranch %83 + %85 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %101 = OpVariable %17 Function + %109 = OpVariable %17 Function + %117 = OpVariable %17 Function + %127 = OpVariable %126 Function + OpStore %101 %19 + OpBranch %102 + %102 = OpLabel + %198 = OpPhi %16 %19 %13 %151 %105 + OpLoopMerge %104 %105 None + OpBranch %106 + %106 = OpLabel + %108 = OpSLessThan %27 %198 %26 + OpBranchConditional %108 %103 %104 + %103 = OpLabel + OpStore %109 %19 + OpBranch %110 + %110 = OpLabel + %199 = OpPhi %16 %19 %103 %149 %113 + OpLoopMerge %112 %113 None + OpBranch %114 + %114 = OpLabel + %116 = OpSLessThan %27 %199 %26 + OpBranchConditional %116 %111 %112 + %111 = OpLabel + OpStore %117 %19 + OpBranch %118 + %118 = OpLabel + %201 = OpPhi %16 %19 %111 %147 %121 + OpLoopMerge %120 %121 None + OpBranch %122 + %122 = OpLabel + %124 = OpSLessThan %27 %201 %26 + OpBranchConditional %124 %119 %120 + %119 = OpLabel + %130 = OpISub %16 %199 %198 + %132 = OpIAdd %16 %198 %44 + %135 = OpIAdd %16 %199 %201 + %138 = OpISub %16 %199 %198 + %142 = OpIAdd %16 %199 %201 + %143 = OpAccessChain %17 %127 %138 %198 %142 + %144 = OpLoad %16 %143 + %145 = OpAccessChain %17 %127 %130 %132 %135 + OpStore %145 %144 + OpBranch %121 + %121 = OpLabel + %147 = OpIAdd %16 %201 %44 + OpStore %117 %147 + OpBranch %118 + %120 = OpLabel + OpBranch %113 + %113 = OpLabel + %149 = OpIAdd %16 %199 %44 + OpStore %109 %149 + OpBranch %110 + %112 = OpLabel + OpBranch %105 + %105 = OpLabel + %151 = OpIAdd %16 %198 %44 + OpStore %101 %151 + OpBranch %102 + %104 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %2 None %3 + %15 = OpLabel + %152 = OpVariable %17 Function + %160 = OpVariable %17 Function + %168 = OpVariable %41 Function + OpStore %152 %19 + OpBranch %153 + %153 = OpLabel + %204 = OpPhi %16 %19 %15 %187 %156 + OpLoopMerge %155 %156 None + OpBranch %157 + %157 = OpLabel + %159 = OpSLessThan %27 %204 %26 + OpBranchConditional %159 %154 %155 + %154 = OpLabel + OpStore %160 %19 + OpBranch %161 + %161 = OpLabel + %205 = OpPhi %16 %19 %154 %185 %164 + OpLoopMerge %163 %164 None + OpBranch %165 + %165 = OpLabel + %167 = OpSLessThan %27 %205 %26 + OpBranchConditional %167 %162 %163 + %162 = OpLabel + %171 = OpIMul %16 %72 %205 + %173 = OpIAdd %16 %171 %204 + %176 = OpIMul %16 %72 %205 + %178 = OpISub %16 %176 %204 + %180 = OpIAdd %16 %178 %179 + %181 = OpAccessChain %17 %168 %204 %180 + %182 = OpLoad %16 %181 + %183 = OpAccessChain %17 %168 %204 %173 + OpStore %183 %182 + OpBranch %164 + %164 = OpLabel + %185 = OpIAdd %16 %205 %44 + OpStore %160 %185 + OpBranch %161 + %163 = OpLabel + OpBranch %156 + %156 = OpLabel + %187 = OpIAdd %16 %204 %44 + OpStore %152 %187 + OpBranch %153 + %155 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 31; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0), + &ld.GetLoopByIndex(1)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + + std::vector expected_entries{ + DistanceEntry(DistanceEntry::Directions::LT, 1), + DistanceEntry(DistanceEntry::Directions::LT, 1)}; + + DistanceVector expected_distance_vector(expected_entries); + + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + EXPECT_FALSE(is_independent); + EXPECT_EQ(expected_distance_vector, dv_entry); + } + + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 62; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + EXPECT_TRUE(is_independent); + } + + { + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 84; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + DistanceVector expected_distance_vector({DistanceEntry(1, 2)}); + + EXPECT_FALSE(is_independent); + EXPECT_EQ(expected_distance_vector, dv_entry); + } + + { + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 119; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{ + &ld.GetLoopByIndex(0), &ld.GetLoopByIndex(1), &ld.GetLoopByIndex(2)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + + std::vector expected_entries{ + DistanceEntry(DistanceEntry::Directions::LT, 1), + DistanceEntry(DistanceEntry::Directions::LT, 1), + DistanceEntry(DistanceEntry::Directions::GT, -1)}; + + DistanceVector expected_distance_vector(expected_entries); + + auto is_independent = analysis.GetDependence(store, load, &dv_entry); + + EXPECT_FALSE(is_independent); + EXPECT_EQ(expected_distance_vector, dv_entry); + } + + { + const Function* f = spvtest::GetFunction(module, 14); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 162; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0), + &ld.GetLoopByIndex(1)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + EXPECT_TRUE(is_independent); + } +} + +TEST(DependencyAnalysis, ConstraintIntersection) { + LoopDependenceAnalysis analysis{nullptr, std::vector{}}; + auto scalar_evolution = analysis.GetScalarEvolution(); + { + // One is none. Other should be returned + auto none = analysis.make_constraint(); + auto x = scalar_evolution->CreateConstant(1); + auto y = scalar_evolution->CreateConstant(10); + auto point = analysis.make_constraint(x, y, nullptr); + + auto ret_0 = analysis.IntersectConstraints(none, point, nullptr, nullptr); + + auto ret_point_0 = ret_0->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + EXPECT_EQ(*x, *ret_point_0->GetSource()); + EXPECT_EQ(*y, *ret_point_0->GetDestination()); + + auto ret_1 = analysis.IntersectConstraints(point, none, nullptr, nullptr); + + auto ret_point_1 = ret_1->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_1); + EXPECT_EQ(*x, *ret_point_1->GetSource()); + EXPECT_EQ(*y, *ret_point_1->GetDestination()); + } + + { + // Both distances + auto x = scalar_evolution->CreateConstant(1); + auto y = scalar_evolution->CreateConstant(10); + + auto distance_0 = analysis.make_constraint(x, nullptr); + auto distance_1 = analysis.make_constraint(y, nullptr); + + // Equal distances + auto ret_0 = + analysis.IntersectConstraints(distance_1, distance_1, nullptr, nullptr); + + auto ret_distance = ret_0->AsDependenceDistance(); + ASSERT_NE(nullptr, ret_distance); + EXPECT_EQ(*y, *ret_distance->GetDistance()); + + // Non-equal distances + auto ret_1 = + analysis.IntersectConstraints(distance_0, distance_1, nullptr, nullptr); + EXPECT_NE(nullptr, ret_1->AsDependenceEmpty()); + } + + { + // Both points + auto x = scalar_evolution->CreateConstant(1); + auto y = scalar_evolution->CreateConstant(10); + + auto point_0 = analysis.make_constraint(x, y, nullptr); + auto point_1 = analysis.make_constraint(x, y, nullptr); + auto point_2 = analysis.make_constraint(y, y, nullptr); + + // Equal points + auto ret_0 = + analysis.IntersectConstraints(point_0, point_1, nullptr, nullptr); + auto ret_point_0 = ret_0->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + EXPECT_EQ(*x, *ret_point_0->GetSource()); + EXPECT_EQ(*y, *ret_point_0->GetDestination()); + + // Non-equal points + auto ret_1 = + analysis.IntersectConstraints(point_0, point_2, nullptr, nullptr); + EXPECT_NE(nullptr, ret_1->AsDependenceEmpty()); + } + + { + // Both lines, parallel + auto a0 = scalar_evolution->CreateConstant(3); + auto b0 = scalar_evolution->CreateConstant(6); + auto c0 = scalar_evolution->CreateConstant(9); + + auto a1 = scalar_evolution->CreateConstant(6); + auto b1 = scalar_evolution->CreateConstant(12); + auto c1 = scalar_evolution->CreateConstant(18); + + auto line_0 = analysis.make_constraint(a0, b0, c0, nullptr); + auto line_1 = analysis.make_constraint(a1, b1, c1, nullptr); + + // Same line, both ways + auto ret_0 = + analysis.IntersectConstraints(line_0, line_1, nullptr, nullptr); + auto ret_1 = + analysis.IntersectConstraints(line_1, line_0, nullptr, nullptr); + + auto ret_line_0 = ret_0->AsDependenceLine(); + auto ret_line_1 = ret_1->AsDependenceLine(); + + EXPECT_NE(nullptr, ret_line_0); + EXPECT_NE(nullptr, ret_line_1); + + // Non-intersecting parallel lines + auto c2 = scalar_evolution->CreateConstant(12); + auto line_2 = analysis.make_constraint(a1, b1, c2, nullptr); + + auto ret_2 = + analysis.IntersectConstraints(line_0, line_2, nullptr, nullptr); + auto ret_3 = + analysis.IntersectConstraints(line_2, line_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + auto c3 = scalar_evolution->CreateConstant(20); + auto line_3 = analysis.make_constraint(a1, b1, c3, nullptr); + + auto ret_4 = + analysis.IntersectConstraints(line_0, line_3, nullptr, nullptr); + auto ret_5 = + analysis.IntersectConstraints(line_3, line_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_4->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_5->AsDependenceEmpty()); + } + + { + // Non-constant line + auto unknown = scalar_evolution->CreateCantComputeNode(); + auto constant = scalar_evolution->CreateConstant(10); + + auto line_0 = analysis.make_constraint(constant, constant, + constant, nullptr); + auto line_1 = analysis.make_constraint(unknown, unknown, + unknown, nullptr); + + auto ret_0 = + analysis.IntersectConstraints(line_0, line_1, nullptr, nullptr); + auto ret_1 = + analysis.IntersectConstraints(line_1, line_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_0->AsDependenceNone()); + EXPECT_NE(nullptr, ret_1->AsDependenceNone()); + } + + { + auto bound_0 = scalar_evolution->CreateConstant(0); + auto bound_1 = scalar_evolution->CreateConstant(20); + + auto a0 = scalar_evolution->CreateConstant(1); + auto b0 = scalar_evolution->CreateConstant(2); + auto c0 = scalar_evolution->CreateConstant(6); + + auto a1 = scalar_evolution->CreateConstant(-1); + auto b1 = scalar_evolution->CreateConstant(2); + auto c1 = scalar_evolution->CreateConstant(2); + + auto line_0 = analysis.make_constraint(a0, b0, c0, nullptr); + auto line_1 = analysis.make_constraint(a1, b1, c1, nullptr); + + // Intersecting lines, has integer solution, in bounds + auto ret_0 = + analysis.IntersectConstraints(line_0, line_1, bound_0, bound_1); + auto ret_1 = + analysis.IntersectConstraints(line_1, line_0, bound_0, bound_1); + + auto ret_point_0 = ret_0->AsDependencePoint(); + auto ret_point_1 = ret_1->AsDependencePoint(); + + EXPECT_NE(nullptr, ret_point_0); + EXPECT_NE(nullptr, ret_point_1); + + auto const_2 = scalar_evolution->CreateConstant(2); + + EXPECT_EQ(*const_2, *ret_point_0->GetSource()); + EXPECT_EQ(*const_2, *ret_point_0->GetDestination()); + + EXPECT_EQ(*const_2, *ret_point_1->GetSource()); + EXPECT_EQ(*const_2, *ret_point_1->GetDestination()); + + // Intersecting lines, has integer solution, out of bounds + auto ret_2 = + analysis.IntersectConstraints(line_0, line_1, bound_0, bound_0); + auto ret_3 = + analysis.IntersectConstraints(line_1, line_0, bound_0, bound_0); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + auto a2 = scalar_evolution->CreateConstant(-4); + auto b2 = scalar_evolution->CreateConstant(1); + auto c2 = scalar_evolution->CreateConstant(0); + + auto a3 = scalar_evolution->CreateConstant(4); + auto b3 = scalar_evolution->CreateConstant(1); + auto c3 = scalar_evolution->CreateConstant(4); + + auto line_2 = analysis.make_constraint(a2, b2, c2, nullptr); + auto line_3 = analysis.make_constraint(a3, b3, c3, nullptr); + + // Intersecting, no integer solution + auto ret_4 = + analysis.IntersectConstraints(line_2, line_3, bound_0, bound_1); + auto ret_5 = + analysis.IntersectConstraints(line_3, line_2, bound_0, bound_1); + + EXPECT_NE(nullptr, ret_4->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_5->AsDependenceEmpty()); + + auto unknown = scalar_evolution->CreateCantComputeNode(); + + // Non-constant bound + auto ret_6 = + analysis.IntersectConstraints(line_0, line_1, unknown, bound_1); + auto ret_7 = + analysis.IntersectConstraints(line_1, line_0, bound_0, unknown); + + EXPECT_NE(nullptr, ret_6->AsDependenceNone()); + EXPECT_NE(nullptr, ret_7->AsDependenceNone()); + } + + { + auto constant_0 = scalar_evolution->CreateConstant(0); + auto constant_1 = scalar_evolution->CreateConstant(1); + auto constant_neg_1 = scalar_evolution->CreateConstant(-1); + auto constant_2 = scalar_evolution->CreateConstant(2); + auto constant_neg_2 = scalar_evolution->CreateConstant(-2); + + auto point_0_0 = analysis.make_constraint( + constant_0, constant_0, nullptr); + auto point_0_1 = analysis.make_constraint( + constant_0, constant_1, nullptr); + auto point_1_0 = analysis.make_constraint( + constant_1, constant_0, nullptr); + auto point_1_1 = analysis.make_constraint( + constant_1, constant_1, nullptr); + auto point_1_2 = analysis.make_constraint( + constant_1, constant_2, nullptr); + auto point_1_neg_1 = analysis.make_constraint( + constant_1, constant_neg_1, nullptr); + auto point_neg_1_1 = analysis.make_constraint( + constant_neg_1, constant_1, nullptr); + + auto line_y_0 = analysis.make_constraint( + constant_0, constant_1, constant_0, nullptr); + auto line_y_1 = analysis.make_constraint( + constant_0, constant_1, constant_1, nullptr); + auto line_y_2 = analysis.make_constraint( + constant_0, constant_1, constant_2, nullptr); + + // Parallel horizontal lines, y = 0 & y = 1, should return no intersection + auto ret = + analysis.IntersectConstraints(line_y_0, line_y_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret->AsDependenceEmpty()); + + // Parallel horizontal lines, y = 1 & y = 2, should return no intersection + auto ret_y_12 = + analysis.IntersectConstraints(line_y_1, line_y_2, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_y_12->AsDependenceEmpty()); + + // Same horizontal lines, y = 0 & y = 0, should return the line + auto ret_y_same_0 = + analysis.IntersectConstraints(line_y_0, line_y_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_y_same_0->AsDependenceLine()); + + // Same horizontal lines, y = 1 & y = 1, should return the line + auto ret_y_same_1 = + analysis.IntersectConstraints(line_y_1, line_y_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_y_same_1->AsDependenceLine()); + + auto line_x_0 = analysis.make_constraint( + constant_1, constant_0, constant_0, nullptr); + auto line_x_1 = analysis.make_constraint( + constant_1, constant_0, constant_1, nullptr); + auto line_x_2 = analysis.make_constraint( + constant_1, constant_0, constant_2, nullptr); + auto line_2x_1 = analysis.make_constraint( + constant_2, constant_0, constant_1, nullptr); + auto line_2x_2 = analysis.make_constraint( + constant_2, constant_0, constant_2, nullptr); + + // Parallel vertical lines, x = 0 & x = 1, should return no intersection + auto ret_x = + analysis.IntersectConstraints(line_x_0, line_x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x->AsDependenceEmpty()); + + // Parallel vertical lines, x = 1 & x = 2, should return no intersection + auto ret_x_12 = + analysis.IntersectConstraints(line_x_1, line_x_2, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x_12->AsDependenceEmpty()); + + // Parallel vertical lines, 2x = 1 & 2x = 2, should return no intersection + auto ret_2x_2_2x_1 = + analysis.IntersectConstraints(line_2x_2, line_2x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2x_2_2x_1->AsDependenceEmpty()); + + // same line, 2x=2 & x = 1 + auto ret_2x_2_x_1 = + analysis.IntersectConstraints(line_2x_2, line_x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2x_2_x_1->AsDependenceLine()); + + // Same vertical lines, x = 0 & x = 0, should return the line + auto ret_x_same_0 = + analysis.IntersectConstraints(line_x_0, line_x_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x_same_0->AsDependenceLine()); + // EXPECT_EQ(*line_x_0, *ret_x_same_0->AsDependenceLine()); + + // Same vertical lines, x = 1 & x = 1, should return the line + auto ret_x_same_1 = + analysis.IntersectConstraints(line_x_1, line_x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x_same_1->AsDependenceLine()); + EXPECT_EQ(*line_x_1, *ret_x_same_1->AsDependenceLine()); + + // x=1 & y = 0, intersect at (1, 0) + auto ret_1_0 = analysis.IntersectConstraints(line_x_1, line_y_0, + constant_neg_1, constant_2); + + auto ret_point_1_0 = ret_1_0->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_1_0); + EXPECT_EQ(*point_1_0, *ret_point_1_0); + + // x=1 & y = 1, intersect at (1, 1) + auto ret_1_1 = analysis.IntersectConstraints(line_x_1, line_y_1, + constant_neg_1, constant_2); + + auto ret_point_1_1 = ret_1_1->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_1_1); + EXPECT_EQ(*point_1_1, *ret_point_1_1); + + // x=0 & y = 0, intersect at (0, 0) + auto ret_0_0 = analysis.IntersectConstraints(line_x_0, line_y_0, + constant_neg_1, constant_2); + + auto ret_point_0_0 = ret_0_0->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_0_0); + EXPECT_EQ(*point_0_0, *ret_point_0_0); + + // x=0 & y = 1, intersect at (0, 1) + auto ret_0_1 = analysis.IntersectConstraints(line_x_0, line_y_1, + constant_neg_1, constant_2); + auto ret_point_0_1 = ret_0_1->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_0_1); + EXPECT_EQ(*point_0_1, *ret_point_0_1); + + // x = 1 & y = 2 + auto ret_1_2 = analysis.IntersectConstraints(line_x_1, line_y_2, + constant_neg_1, constant_2); + auto ret_point_1_2 = ret_1_2->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_1_2); + EXPECT_EQ(*point_1_2, *ret_point_1_2); + + auto line_x_y_0 = analysis.make_constraint( + constant_1, constant_1, constant_0, nullptr); + auto line_x_y_1 = analysis.make_constraint( + constant_1, constant_1, constant_1, nullptr); + + // x+y=0 & x=0, intersect (0, 0) + auto ret_xy_0_x_0 = analysis.IntersectConstraints( + line_x_y_0, line_x_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_x_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_xy_0_x_0); + + // x+y=0 & y=0, intersect (0, 0) + auto ret_xy_0_y_0 = analysis.IntersectConstraints( + line_x_y_0, line_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_y_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_xy_0_y_0); + + // x+y=0 & x=1, intersect (1, -1) + auto ret_xy_0_x_1 = analysis.IntersectConstraints( + line_x_y_0, line_x_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_x_1->AsDependencePoint()); + EXPECT_EQ(*point_1_neg_1, *ret_xy_0_x_1); + + // x+y=0 & y=1, intersect (-1, 1) + auto ret_xy_0_y_1 = analysis.IntersectConstraints( + line_x_y_0, line_y_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_y_1->AsDependencePoint()); + EXPECT_EQ(*point_neg_1_1, *ret_xy_0_y_1); + + // x=0 & x+y=0, intersect (0, 0) + auto ret_x_0_xy_0 = analysis.IntersectConstraints( + line_x_0, line_x_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_x_0_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_x_0_xy_0); + + // y=0 & x+y=0, intersect (0, 0) + auto ret_y_0_xy_0 = analysis.IntersectConstraints( + line_y_0, line_x_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_y_0_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_y_0_xy_0); + + // x=1 & x+y=0, intersect (1, -1) + auto ret_x_1_xy_0 = analysis.IntersectConstraints( + line_x_1, line_x_y_0, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_x_1_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_1_neg_1, *ret_x_1_xy_0); + + // y=1 & x+y=0, intersect (-1, 1) + auto ret_y_1_xy_0 = analysis.IntersectConstraints( + line_y_1, line_x_y_0, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_y_1_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_neg_1_1, *ret_y_1_xy_0); + + // x+y=1 & x=0, intersect (0, 1) + auto ret_xy_1_x_0 = analysis.IntersectConstraints( + line_x_y_1, line_x_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_x_0->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_xy_1_x_0); + + // x+y=1 & y=0, intersect (1, 0) + auto ret_xy_1_y_0 = analysis.IntersectConstraints( + line_x_y_1, line_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_y_0->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_xy_1_y_0); + + // x+y=1 & x=1, intersect (1, 0) + auto ret_xy_1_x_1 = analysis.IntersectConstraints( + line_x_y_1, line_x_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_x_1->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_xy_1_x_1); + + // x+y=1 & y=1, intersect (0, 1) + auto ret_xy_1_y_1 = analysis.IntersectConstraints( + line_x_y_1, line_y_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_y_1->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_xy_1_y_1); + + // x=0 & x+y=1, intersect (0, 1) + auto ret_x_0_xy_1 = analysis.IntersectConstraints( + line_x_0, line_x_y_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_x_0_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_x_0_xy_1); + + // y=0 & x+y=1, intersect (1, 0) + auto ret_y_0_xy_1 = analysis.IntersectConstraints( + line_y_0, line_x_y_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_y_0_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_y_0_xy_1); + + // x=1 & x+y=1, intersect (1, 0) + auto ret_x_1_xy_1 = analysis.IntersectConstraints( + line_x_1, line_x_y_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_x_1_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_x_1_xy_1); + + // y=1 & x+y=1, intersect (0, 1) + auto ret_y_1_xy_1 = analysis.IntersectConstraints( + line_y_1, line_x_y_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_y_1_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_y_1_xy_1); + } + + { + // Line and point + auto a = scalar_evolution->CreateConstant(3); + auto b = scalar_evolution->CreateConstant(10); + auto c = scalar_evolution->CreateConstant(16); + + auto line = analysis.make_constraint(a, b, c, nullptr); + + // Point on line + auto x = scalar_evolution->CreateConstant(2); + auto y = scalar_evolution->CreateConstant(1); + auto point_0 = analysis.make_constraint(x, y, nullptr); + + auto ret_0 = analysis.IntersectConstraints(line, point_0, nullptr, nullptr); + auto ret_1 = analysis.IntersectConstraints(point_0, line, nullptr, nullptr); + + auto ret_point_0 = ret_0->AsDependencePoint(); + auto ret_point_1 = ret_1->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + ASSERT_NE(nullptr, ret_point_1); + + EXPECT_EQ(*x, *ret_point_0->GetSource()); + EXPECT_EQ(*y, *ret_point_0->GetDestination()); + + EXPECT_EQ(*x, *ret_point_1->GetSource()); + EXPECT_EQ(*y, *ret_point_1->GetDestination()); + + // Point not on line + auto point_1 = analysis.make_constraint(a, a, nullptr); + + auto ret_2 = analysis.IntersectConstraints(line, point_1, nullptr, nullptr); + auto ret_3 = analysis.IntersectConstraints(point_1, line, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + // Non-constant + auto unknown = scalar_evolution->CreateCantComputeNode(); + + auto point_2 = + analysis.make_constraint(unknown, x, nullptr); + + auto ret_4 = analysis.IntersectConstraints(line, point_2, nullptr, nullptr); + auto ret_5 = analysis.IntersectConstraints(point_2, line, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_4->AsDependenceNone()); + EXPECT_NE(nullptr, ret_5->AsDependenceNone()); + } + + { + // Distance and point + auto d = scalar_evolution->CreateConstant(5); + auto distance = analysis.make_constraint(d, nullptr); + + // Point on line + auto x = scalar_evolution->CreateConstant(10); + auto point_0 = analysis.make_constraint(d, x, nullptr); + + auto ret_0 = + analysis.IntersectConstraints(distance, point_0, nullptr, nullptr); + auto ret_1 = + analysis.IntersectConstraints(point_0, distance, nullptr, nullptr); + + auto ret_point_0 = ret_0->AsDependencePoint(); + auto ret_point_1 = ret_1->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + ASSERT_NE(nullptr, ret_point_1); + + // Point not on line + auto point_1 = analysis.make_constraint(x, x, nullptr); + + auto ret_2 = + analysis.IntersectConstraints(distance, point_1, nullptr, nullptr); + auto ret_3 = + analysis.IntersectConstraints(point_1, distance, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + // Non-constant + auto unknown = scalar_evolution->CreateCantComputeNode(); + auto unknown_distance = + analysis.make_constraint(unknown, nullptr); + + auto ret_4 = analysis.IntersectConstraints(unknown_distance, point_1, + nullptr, nullptr); + auto ret_5 = analysis.IntersectConstraints(point_1, unknown_distance, + nullptr, nullptr); + + EXPECT_NE(nullptr, ret_4->AsDependenceNone()); + EXPECT_NE(nullptr, ret_5->AsDependenceNone()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/dependence_analysis_helpers.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/dependence_analysis_helpers.cpp new file mode 100644 index 000000000..715cf541d --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/dependence_analysis_helpers.cpp @@ -0,0 +1,3017 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DependencyAnalysisHelpers = ::testing::Test; + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + int[10][10] arr; + int i = 0; + int j = 0; + for (; i < 10 && j < 10; i++, j++) { + arr[i][j] = arr[i][j]; + } +} +void b() { + int[10] arr; + for (int i = 0; i < 10; i+=2) { + arr[i] = arr[i]; + } +} +void main(){ + a(); + b(); +} +*/ +TEST(DependencyAnalysisHelpers, UnsupportedLoops) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "i" + OpName %14 "j" + OpName %32 "arr" + OpName %45 "i" + OpName %54 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpConstant %10 0 + %21 = OpConstant %10 10 + %22 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %10 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %41 = OpConstant %10 1 + %53 = OpTypePointer Function %29 + %60 = OpConstant %10 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %63 = OpFunctionCall %2 %6 + %64 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %14 = OpVariable %11 Function + %32 = OpVariable %31 Function + OpStore %12 %13 + OpStore %14 %13 + OpBranch %15 + %15 = OpLabel + %65 = OpPhi %10 %13 %7 %42 %18 + %66 = OpPhi %10 %13 %7 %44 %18 + OpLoopMerge %17 %18 None + OpBranch %19 + %19 = OpLabel + %23 = OpSLessThan %22 %65 %21 + %25 = OpSLessThan %22 %66 %21 + %26 = OpLogicalAnd %22 %23 %25 + OpBranchConditional %26 %16 %17 + %16 = OpLabel + %37 = OpAccessChain %11 %32 %65 %66 + %38 = OpLoad %10 %37 + %39 = OpAccessChain %11 %32 %65 %66 + OpStore %39 %38 + OpBranch %18 + %18 = OpLabel + %42 = OpIAdd %10 %65 %41 + OpStore %12 %42 + %44 = OpIAdd %10 %66 %41 + OpStore %14 %44 + OpBranch %15 + %17 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %45 = OpVariable %11 Function + %54 = OpVariable %53 Function + OpStore %45 %13 + OpBranch %46 + %46 = OpLabel + %67 = OpPhi %10 %13 %9 %62 %49 + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %52 = OpSLessThan %22 %67 %21 + OpBranchConditional %52 %47 %48 + %47 = OpLabel + %57 = OpAccessChain %11 %54 %67 + %58 = OpLoad %10 %57 + %59 = OpAccessChain %11 %54 %67 + OpStore %59 %58 + OpBranch %49 + %49 = OpLabel + %62 = OpIAdd %10 %67 %60 + OpStore %45 %62 + OpBranch %46 + %48 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1] = {nullptr}; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 16)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + // 38 -> 39 + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.IsSupportedLoop(loops[0])); + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(38), + store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1] = {nullptr}; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 47)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + // 58 -> 59 + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.IsSupportedLoop(loops[0])); + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(58), + store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + for (int i = -10; i < 0; i++) { + + } +} +void b() { + for (int i = -5; i < 5; i++) { + + } +} +void c() { + for (int i = 0; i < 10; i++) { + + } +} +void d() { + for (int i = 5; i < 15; i++) { + + } +} +void e() { + for (int i = -10; i <= 0; i++) { + + } +} +void f() { + for (int i = -5; i <= 5; i++) { + + } +} +void g() { + for (int i = 0; i <= 10; i++) { + + } +} +void h() { + for (int i = 5; i <= 15; i++) { + + } +} +void i() { + for (int i = 0; i > -10; i--) { + + } +} +void j() { + for (int i = 5; i > -5; i--) { + + } +} +void k() { + for (int i = 10; i > 0; i--) { + + } +} +void l() { + for (int i = 15; i > 5; i--) { + + } +} +void m() { + for (int i = 0; i >= -10; i--) { + + } +} +void n() { + for (int i = 5; i >= -5; i--) { + + } +} +void o() { + for (int i = 10; i >= 0; i--) { + + } +} +void p() { + for (int i = 15; i >= 5; i--) { + + } +} +void main(){ + a(); + b(); + c(); + d(); + e(); + f(); + g(); + h(); + i(); + j(); + k(); + l(); + m(); + n(); + o(); + p(); +} +*/ +TEST(DependencyAnalysisHelpers, loop_information) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %14 "e(" + OpName %16 "f(" + OpName %18 "g(" + OpName %20 "h(" + OpName %22 "i(" + OpName %24 "j(" + OpName %26 "k(" + OpName %28 "l(" + OpName %30 "m(" + OpName %32 "n(" + OpName %34 "o(" + OpName %36 "p(" + OpName %40 "i" + OpName %54 "i" + OpName %66 "i" + OpName %77 "i" + OpName %88 "i" + OpName %98 "i" + OpName %108 "i" + OpName %118 "i" + OpName %128 "i" + OpName %138 "i" + OpName %148 "i" + OpName %158 "i" + OpName %168 "i" + OpName %178 "i" + OpName %188 "i" + OpName %198 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %38 = OpTypeInt 32 1 + %39 = OpTypePointer Function %38 + %41 = OpConstant %38 -10 + %48 = OpConstant %38 0 + %49 = OpTypeBool + %52 = OpConstant %38 1 + %55 = OpConstant %38 -5 + %62 = OpConstant %38 5 + %73 = OpConstant %38 10 + %84 = OpConstant %38 15 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %208 = OpFunctionCall %2 %6 + %209 = OpFunctionCall %2 %8 + %210 = OpFunctionCall %2 %10 + %211 = OpFunctionCall %2 %12 + %212 = OpFunctionCall %2 %14 + %213 = OpFunctionCall %2 %16 + %214 = OpFunctionCall %2 %18 + %215 = OpFunctionCall %2 %20 + %216 = OpFunctionCall %2 %22 + %217 = OpFunctionCall %2 %24 + %218 = OpFunctionCall %2 %26 + %219 = OpFunctionCall %2 %28 + %220 = OpFunctionCall %2 %30 + %221 = OpFunctionCall %2 %32 + %222 = OpFunctionCall %2 %34 + %223 = OpFunctionCall %2 %36 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %40 = OpVariable %39 Function + OpStore %40 %41 + OpBranch %42 + %42 = OpLabel + %224 = OpPhi %38 %41 %7 %53 %45 + OpLoopMerge %44 %45 None + OpBranch %46 + %46 = OpLabel + %50 = OpSLessThan %49 %224 %48 + OpBranchConditional %50 %43 %44 + %43 = OpLabel + OpBranch %45 + %45 = OpLabel + %53 = OpIAdd %38 %224 %52 + OpStore %40 %53 + OpBranch %42 + %44 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %54 = OpVariable %39 Function + OpStore %54 %55 + OpBranch %56 + %56 = OpLabel + %225 = OpPhi %38 %55 %9 %65 %59 + OpLoopMerge %58 %59 None + OpBranch %60 + %60 = OpLabel + %63 = OpSLessThan %49 %225 %62 + OpBranchConditional %63 %57 %58 + %57 = OpLabel + OpBranch %59 + %59 = OpLabel + %65 = OpIAdd %38 %225 %52 + OpStore %54 %65 + OpBranch %56 + %58 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %66 = OpVariable %39 Function + OpStore %66 %48 + OpBranch %67 + %67 = OpLabel + %226 = OpPhi %38 %48 %11 %76 %70 + OpLoopMerge %69 %70 None + OpBranch %71 + %71 = OpLabel + %74 = OpSLessThan %49 %226 %73 + OpBranchConditional %74 %68 %69 + %68 = OpLabel + OpBranch %70 + %70 = OpLabel + %76 = OpIAdd %38 %226 %52 + OpStore %66 %76 + OpBranch %67 + %69 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %77 = OpVariable %39 Function + OpStore %77 %62 + OpBranch %78 + %78 = OpLabel + %227 = OpPhi %38 %62 %13 %87 %81 + OpLoopMerge %80 %81 None + OpBranch %82 + %82 = OpLabel + %85 = OpSLessThan %49 %227 %84 + OpBranchConditional %85 %79 %80 + %79 = OpLabel + OpBranch %81 + %81 = OpLabel + %87 = OpIAdd %38 %227 %52 + OpStore %77 %87 + OpBranch %78 + %80 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %2 None %3 + %15 = OpLabel + %88 = OpVariable %39 Function + OpStore %88 %41 + OpBranch %89 + %89 = OpLabel + %228 = OpPhi %38 %41 %15 %97 %92 + OpLoopMerge %91 %92 None + OpBranch %93 + %93 = OpLabel + %95 = OpSLessThanEqual %49 %228 %48 + OpBranchConditional %95 %90 %91 + %90 = OpLabel + OpBranch %92 + %92 = OpLabel + %97 = OpIAdd %38 %228 %52 + OpStore %88 %97 + OpBranch %89 + %91 = OpLabel + OpReturn + OpFunctionEnd + %16 = OpFunction %2 None %3 + %17 = OpLabel + %98 = OpVariable %39 Function + OpStore %98 %55 + OpBranch %99 + %99 = OpLabel + %229 = OpPhi %38 %55 %17 %107 %102 + OpLoopMerge %101 %102 None + OpBranch %103 + %103 = OpLabel + %105 = OpSLessThanEqual %49 %229 %62 + OpBranchConditional %105 %100 %101 + %100 = OpLabel + OpBranch %102 + %102 = OpLabel + %107 = OpIAdd %38 %229 %52 + OpStore %98 %107 + OpBranch %99 + %101 = OpLabel + OpReturn + OpFunctionEnd + %18 = OpFunction %2 None %3 + %19 = OpLabel + %108 = OpVariable %39 Function + OpStore %108 %48 + OpBranch %109 + %109 = OpLabel + %230 = OpPhi %38 %48 %19 %117 %112 + OpLoopMerge %111 %112 None + OpBranch %113 + %113 = OpLabel + %115 = OpSLessThanEqual %49 %230 %73 + OpBranchConditional %115 %110 %111 + %110 = OpLabel + OpBranch %112 + %112 = OpLabel + %117 = OpIAdd %38 %230 %52 + OpStore %108 %117 + OpBranch %109 + %111 = OpLabel + OpReturn + OpFunctionEnd + %20 = OpFunction %2 None %3 + %21 = OpLabel + %118 = OpVariable %39 Function + OpStore %118 %62 + OpBranch %119 + %119 = OpLabel + %231 = OpPhi %38 %62 %21 %127 %122 + OpLoopMerge %121 %122 None + OpBranch %123 + %123 = OpLabel + %125 = OpSLessThanEqual %49 %231 %84 + OpBranchConditional %125 %120 %121 + %120 = OpLabel + OpBranch %122 + %122 = OpLabel + %127 = OpIAdd %38 %231 %52 + OpStore %118 %127 + OpBranch %119 + %121 = OpLabel + OpReturn + OpFunctionEnd + %22 = OpFunction %2 None %3 + %23 = OpLabel + %128 = OpVariable %39 Function + OpStore %128 %48 + OpBranch %129 + %129 = OpLabel + %232 = OpPhi %38 %48 %23 %137 %132 + OpLoopMerge %131 %132 None + OpBranch %133 + %133 = OpLabel + %135 = OpSGreaterThan %49 %232 %41 + OpBranchConditional %135 %130 %131 + %130 = OpLabel + OpBranch %132 + %132 = OpLabel + %137 = OpISub %38 %232 %52 + OpStore %128 %137 + OpBranch %129 + %131 = OpLabel + OpReturn + OpFunctionEnd + %24 = OpFunction %2 None %3 + %25 = OpLabel + %138 = OpVariable %39 Function + OpStore %138 %62 + OpBranch %139 + %139 = OpLabel + %233 = OpPhi %38 %62 %25 %147 %142 + OpLoopMerge %141 %142 None + OpBranch %143 + %143 = OpLabel + %145 = OpSGreaterThan %49 %233 %55 + OpBranchConditional %145 %140 %141 + %140 = OpLabel + OpBranch %142 + %142 = OpLabel + %147 = OpISub %38 %233 %52 + OpStore %138 %147 + OpBranch %139 + %141 = OpLabel + OpReturn + OpFunctionEnd + %26 = OpFunction %2 None %3 + %27 = OpLabel + %148 = OpVariable %39 Function + OpStore %148 %73 + OpBranch %149 + %149 = OpLabel + %234 = OpPhi %38 %73 %27 %157 %152 + OpLoopMerge %151 %152 None + OpBranch %153 + %153 = OpLabel + %155 = OpSGreaterThan %49 %234 %48 + OpBranchConditional %155 %150 %151 + %150 = OpLabel + OpBranch %152 + %152 = OpLabel + %157 = OpISub %38 %234 %52 + OpStore %148 %157 + OpBranch %149 + %151 = OpLabel + OpReturn + OpFunctionEnd + %28 = OpFunction %2 None %3 + %29 = OpLabel + %158 = OpVariable %39 Function + OpStore %158 %84 + OpBranch %159 + %159 = OpLabel + %235 = OpPhi %38 %84 %29 %167 %162 + OpLoopMerge %161 %162 None + OpBranch %163 + %163 = OpLabel + %165 = OpSGreaterThan %49 %235 %62 + OpBranchConditional %165 %160 %161 + %160 = OpLabel + OpBranch %162 + %162 = OpLabel + %167 = OpISub %38 %235 %52 + OpStore %158 %167 + OpBranch %159 + %161 = OpLabel + OpReturn + OpFunctionEnd + %30 = OpFunction %2 None %3 + %31 = OpLabel + %168 = OpVariable %39 Function + OpStore %168 %48 + OpBranch %169 + %169 = OpLabel + %236 = OpPhi %38 %48 %31 %177 %172 + OpLoopMerge %171 %172 None + OpBranch %173 + %173 = OpLabel + %175 = OpSGreaterThanEqual %49 %236 %41 + OpBranchConditional %175 %170 %171 + %170 = OpLabel + OpBranch %172 + %172 = OpLabel + %177 = OpISub %38 %236 %52 + OpStore %168 %177 + OpBranch %169 + %171 = OpLabel + OpReturn + OpFunctionEnd + %32 = OpFunction %2 None %3 + %33 = OpLabel + %178 = OpVariable %39 Function + OpStore %178 %62 + OpBranch %179 + %179 = OpLabel + %237 = OpPhi %38 %62 %33 %187 %182 + OpLoopMerge %181 %182 None + OpBranch %183 + %183 = OpLabel + %185 = OpSGreaterThanEqual %49 %237 %55 + OpBranchConditional %185 %180 %181 + %180 = OpLabel + OpBranch %182 + %182 = OpLabel + %187 = OpISub %38 %237 %52 + OpStore %178 %187 + OpBranch %179 + %181 = OpLabel + OpReturn + OpFunctionEnd + %34 = OpFunction %2 None %3 + %35 = OpLabel + %188 = OpVariable %39 Function + OpStore %188 %73 + OpBranch %189 + %189 = OpLabel + %238 = OpPhi %38 %73 %35 %197 %192 + OpLoopMerge %191 %192 None + OpBranch %193 + %193 = OpLabel + %195 = OpSGreaterThanEqual %49 %238 %48 + OpBranchConditional %195 %190 %191 + %190 = OpLabel + OpBranch %192 + %192 = OpLabel + %197 = OpISub %38 %238 %52 + OpStore %188 %197 + OpBranch %189 + %191 = OpLabel + OpReturn + OpFunctionEnd + %36 = OpFunction %2 None %3 + %37 = OpLabel + %198 = OpVariable %39 Function + OpStore %198 %84 + OpBranch %199 + %199 = OpLabel + %239 = OpPhi %38 %84 %37 %207 %202 + OpLoopMerge %201 %202 None + OpBranch %203 + %203 = OpLabel + %205 = OpSGreaterThanEqual %49 %239 %62 + OpBranchConditional %205 %200 %201 + %200 = OpLabel + OpBranch %202 + %202 = OpLabel + %207 = OpISub %38 %239 %52 + OpStore %198 %207 + OpBranch %199 + %201 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -1); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(-1)); + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 4); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(4)); + } + { + // Function c + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 9); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(9)); + } + { + // Function d + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 14); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(14)); + } + { + // Function e + const Function* f = spvtest::GetFunction(module, 14); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(0)); + } + { + // Function f + const Function* f = spvtest::GetFunction(module, 16); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(5)); + } + { + // Function g + const Function* f = spvtest::GetFunction(module, 18); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(10)); + } + { + // Function h + const Function* f = spvtest::GetFunction(module, 20); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 15); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(15)); + } + { + // Function i + const Function* f = spvtest::GetFunction(module, 22); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -9); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-9)); + } + { + // Function j + const Function* f = spvtest::GetFunction(module, 24); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -4); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-4)); + } + { + // Function k + const Function* f = spvtest::GetFunction(module, 26); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 1); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(1)); + } + { + // Function l + const Function* f = spvtest::GetFunction(module, 28); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 15); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 6); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(15)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(6)); + } + { + // Function m + const Function* f = spvtest::GetFunction(module, 30); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -10); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-10)); + } + { + // Function n + const Function* f = spvtest::GetFunction(module, 32); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -5); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-5)); + } + { + // Function o + const Function* f = spvtest::GetFunction(module, 34); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(0)); + } + { + // Function p + const Function* f = spvtest::GetFunction(module, 36); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 15); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(15)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(5)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + for (int i = 0; i < 10; i++) { + + } +} +*/ +TEST(DependencyAnalysisHelpers, bounds_checks) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %22 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %22 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %22 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // We need a shader that includes a loop for this test so we can build a + // LoopDependenceAnalaysis + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_TRUE(analysis.IsWithinBounds(0, 0, 0)); + EXPECT_TRUE(analysis.IsWithinBounds(0, -1, 0)); + EXPECT_TRUE(analysis.IsWithinBounds(0, 0, 1)); + EXPECT_TRUE(analysis.IsWithinBounds(0, -1, 1)); + EXPECT_TRUE(analysis.IsWithinBounds(-2, -2, -2)); + EXPECT_TRUE(analysis.IsWithinBounds(-2, -3, 0)); + EXPECT_TRUE(analysis.IsWithinBounds(-2, 0, -3)); + EXPECT_TRUE(analysis.IsWithinBounds(2, 2, 2)); + EXPECT_TRUE(analysis.IsWithinBounds(2, 3, 0)); + + EXPECT_FALSE(analysis.IsWithinBounds(2, 3, 3)); + EXPECT_FALSE(analysis.IsWithinBounds(0, 1, 5)); + EXPECT_FALSE(analysis.IsWithinBounds(0, -1, -4)); + EXPECT_FALSE(analysis.IsWithinBounds(-2, -4, -3)); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 in_vec; +// Loop iterates from constant to symbolic +void a() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 0; i < N; i++) { // Bounds are N - 0 - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void b() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 0; i <= N; i++) { // Bounds are N - 0 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void c() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 9; i > N; i--) { // Bounds are 9 - N - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void d() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 9; i >= N; i--) { // Bounds are 9 - N + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void main(){ + a(); + b(); + c(); + d(); +} +*/ +TEST(DependencyAnalysisHelpers, const_to_symbolic) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %20 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %16 "N" + OpName %20 "in_vec" + OpName %27 "i" + OpName %41 "arr" + OpName %59 "N" + OpName %63 "i" + OpName %72 "arr" + OpName %89 "N" + OpName %93 "i" + OpName %103 "arr" + OpName %120 "N" + OpName %124 "i" + OpName %133 "arr" + OpDecorate %20 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpTypeFloat 32 + %18 = OpTypeVector %17 4 + %19 = OpTypePointer Input %18 + %20 = OpVariable %19 Input + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 0 + %23 = OpTypePointer Input %17 + %28 = OpConstant %14 0 + %36 = OpTypeBool + %38 = OpConstant %21 10 + %39 = OpTypeArray %14 %38 + %40 = OpTypePointer Function %39 + %57 = OpConstant %14 1 + %94 = OpConstant %14 9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %150 = OpFunctionCall %2 %6 + %151 = OpFunctionCall %2 %8 + %152 = OpFunctionCall %2 %10 + %153 = OpFunctionCall %2 %12 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %16 = OpVariable %15 Function + %27 = OpVariable %15 Function + %41 = OpVariable %40 Function + %24 = OpAccessChain %23 %20 %22 + %25 = OpLoad %17 %24 + %26 = OpConvertFToS %14 %25 + OpStore %16 %26 + OpStore %27 %28 + OpBranch %29 + %29 = OpLabel + %154 = OpPhi %14 %28 %7 %58 %32 + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %37 = OpSLessThan %36 %154 %26 + OpBranchConditional %37 %30 %31 + %30 = OpLabel + %45 = OpIAdd %14 %154 %26 + %46 = OpAccessChain %15 %41 %45 + %47 = OpLoad %14 %46 + %48 = OpAccessChain %15 %41 %154 + OpStore %48 %47 + %51 = OpIAdd %14 %154 %26 + %53 = OpAccessChain %15 %41 %154 + %54 = OpLoad %14 %53 + %55 = OpAccessChain %15 %41 %51 + OpStore %55 %54 + OpBranch %32 + %32 = OpLabel + %58 = OpIAdd %14 %154 %57 + OpStore %27 %58 + OpBranch %29 + %31 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %59 = OpVariable %15 Function + %63 = OpVariable %15 Function + %72 = OpVariable %40 Function + %60 = OpAccessChain %23 %20 %22 + %61 = OpLoad %17 %60 + %62 = OpConvertFToS %14 %61 + OpStore %59 %62 + OpStore %63 %28 + OpBranch %64 + %64 = OpLabel + %155 = OpPhi %14 %28 %9 %88 %67 + OpLoopMerge %66 %67 None + OpBranch %68 + %68 = OpLabel + %71 = OpSLessThanEqual %36 %155 %62 + OpBranchConditional %71 %65 %66 + %65 = OpLabel + %76 = OpIAdd %14 %155 %62 + %77 = OpAccessChain %15 %72 %76 + %78 = OpLoad %14 %77 + %79 = OpAccessChain %15 %72 %155 + OpStore %79 %78 + %82 = OpIAdd %14 %155 %62 + %84 = OpAccessChain %15 %72 %155 + %85 = OpLoad %14 %84 + %86 = OpAccessChain %15 %72 %82 + OpStore %86 %85 + OpBranch %67 + %67 = OpLabel + %88 = OpIAdd %14 %155 %57 + OpStore %63 %88 + OpBranch %64 + %66 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %89 = OpVariable %15 Function + %93 = OpVariable %15 Function + %103 = OpVariable %40 Function + %90 = OpAccessChain %23 %20 %22 + %91 = OpLoad %17 %90 + %92 = OpConvertFToS %14 %91 + OpStore %89 %92 + OpStore %93 %94 + OpBranch %95 + %95 = OpLabel + %156 = OpPhi %14 %94 %11 %119 %98 + OpLoopMerge %97 %98 None + OpBranch %99 + %99 = OpLabel + %102 = OpSGreaterThan %36 %156 %92 + OpBranchConditional %102 %96 %97 + %96 = OpLabel + %107 = OpIAdd %14 %156 %92 + %108 = OpAccessChain %15 %103 %107 + %109 = OpLoad %14 %108 + %110 = OpAccessChain %15 %103 %156 + OpStore %110 %109 + %113 = OpIAdd %14 %156 %92 + %115 = OpAccessChain %15 %103 %156 + %116 = OpLoad %14 %115 + %117 = OpAccessChain %15 %103 %113 + OpStore %117 %116 + OpBranch %98 + %98 = OpLabel + %119 = OpISub %14 %156 %57 + OpStore %93 %119 + OpBranch %95 + %97 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %120 = OpVariable %15 Function + %124 = OpVariable %15 Function + %133 = OpVariable %40 Function + %121 = OpAccessChain %23 %20 %22 + %122 = OpLoad %17 %121 + %123 = OpConvertFToS %14 %122 + OpStore %120 %123 + OpStore %124 %94 + OpBranch %125 + %125 = OpLabel + %157 = OpPhi %14 %94 %13 %149 %128 + OpLoopMerge %127 %128 None + OpBranch %129 + %129 = OpLabel + %132 = OpSGreaterThanEqual %36 %157 %123 + OpBranchConditional %132 %126 %127 + %126 = OpLabel + %137 = OpIAdd %14 %157 %123 + %138 = OpAccessChain %15 %133 %137 + %139 = OpLoad %14 %138 + %140 = OpAccessChain %15 %133 %157 + OpStore %140 %139 + %143 = OpIAdd %14 %157 %123 + %145 = OpAccessChain %15 %133 %157 + %146 = OpLoad %14 %145 + %147 = OpAccessChain %15 %133 %143 + OpStore %147 %146 + OpBranch %128 + %128 = OpLabel + %149 = OpISub %14 %157 %57 + OpStore %124 %149 + OpBranch %125 + %127 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 47 -> 48 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(47) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent and supported. + EXPECT_TRUE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 54 -> 55 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(54) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 65)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 78 -> 79 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(78) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 85 -> 86 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(85) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function c + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 96)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 109 -> 110 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(109) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 116 -> 117 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(116) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function d + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 126)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 139 -> 140 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(139) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 146 -> 147 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(146) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 in_vec; +// Loop iterates from symbolic to constant +void a() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i < 9; i++) { // Bounds are 9 - N - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void b() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i <= 9; i++) { // Bounds are 9 - N + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void c() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i > 0; i--) { // Bounds are N - 0 - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void d() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i >= 0; i--) { // Bounds are N - 0 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void main(){ + a(); + b(); + c(); + d(); +} +*/ +TEST(DependencyAnalysisHelpers, symbolic_to_const) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %20 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %16 "N" + OpName %20 "in_vec" + OpName %27 "i" + OpName %41 "arr" + OpName %59 "N" + OpName %63 "i" + OpName %72 "arr" + OpName %89 "N" + OpName %93 "i" + OpName %103 "arr" + OpName %120 "N" + OpName %124 "i" + OpName %133 "arr" + OpDecorate %20 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpTypeFloat 32 + %18 = OpTypeVector %17 4 + %19 = OpTypePointer Input %18 + %20 = OpVariable %19 Input + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 0 + %23 = OpTypePointer Input %17 + %35 = OpConstant %14 9 + %36 = OpTypeBool + %38 = OpConstant %21 10 + %39 = OpTypeArray %14 %38 + %40 = OpTypePointer Function %39 + %57 = OpConstant %14 1 + %101 = OpConstant %14 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %150 = OpFunctionCall %2 %6 + %151 = OpFunctionCall %2 %8 + %152 = OpFunctionCall %2 %10 + %153 = OpFunctionCall %2 %12 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %16 = OpVariable %15 Function + %27 = OpVariable %15 Function + %41 = OpVariable %40 Function + %24 = OpAccessChain %23 %20 %22 + %25 = OpLoad %17 %24 + %26 = OpConvertFToS %14 %25 + OpStore %16 %26 + OpStore %27 %26 + OpBranch %29 + %29 = OpLabel + %154 = OpPhi %14 %26 %7 %58 %32 + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %37 = OpSLessThan %36 %154 %35 + OpBranchConditional %37 %30 %31 + %30 = OpLabel + %45 = OpIAdd %14 %154 %26 + %46 = OpAccessChain %15 %41 %45 + %47 = OpLoad %14 %46 + %48 = OpAccessChain %15 %41 %154 + OpStore %48 %47 + %51 = OpIAdd %14 %154 %26 + %53 = OpAccessChain %15 %41 %154 + %54 = OpLoad %14 %53 + %55 = OpAccessChain %15 %41 %51 + OpStore %55 %54 + OpBranch %32 + %32 = OpLabel + %58 = OpIAdd %14 %154 %57 + OpStore %27 %58 + OpBranch %29 + %31 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %59 = OpVariable %15 Function + %63 = OpVariable %15 Function + %72 = OpVariable %40 Function + %60 = OpAccessChain %23 %20 %22 + %61 = OpLoad %17 %60 + %62 = OpConvertFToS %14 %61 + OpStore %59 %62 + OpStore %63 %62 + OpBranch %65 + %65 = OpLabel + %155 = OpPhi %14 %62 %9 %88 %68 + OpLoopMerge %67 %68 None + OpBranch %69 + %69 = OpLabel + %71 = OpSLessThanEqual %36 %155 %35 + OpBranchConditional %71 %66 %67 + %66 = OpLabel + %76 = OpIAdd %14 %155 %62 + %77 = OpAccessChain %15 %72 %76 + %78 = OpLoad %14 %77 + %79 = OpAccessChain %15 %72 %155 + OpStore %79 %78 + %82 = OpIAdd %14 %155 %62 + %84 = OpAccessChain %15 %72 %155 + %85 = OpLoad %14 %84 + %86 = OpAccessChain %15 %72 %82 + OpStore %86 %85 + OpBranch %68 + %68 = OpLabel + %88 = OpIAdd %14 %155 %57 + OpStore %63 %88 + OpBranch %65 + %67 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %89 = OpVariable %15 Function + %93 = OpVariable %15 Function + %103 = OpVariable %40 Function + %90 = OpAccessChain %23 %20 %22 + %91 = OpLoad %17 %90 + %92 = OpConvertFToS %14 %91 + OpStore %89 %92 + OpStore %93 %92 + OpBranch %95 + %95 = OpLabel + %156 = OpPhi %14 %92 %11 %119 %98 + OpLoopMerge %97 %98 None + OpBranch %99 + %99 = OpLabel + %102 = OpSGreaterThan %36 %156 %101 + OpBranchConditional %102 %96 %97 + %96 = OpLabel + %107 = OpIAdd %14 %156 %92 + %108 = OpAccessChain %15 %103 %107 + %109 = OpLoad %14 %108 + %110 = OpAccessChain %15 %103 %156 + OpStore %110 %109 + %113 = OpIAdd %14 %156 %92 + %115 = OpAccessChain %15 %103 %156 + %116 = OpLoad %14 %115 + %117 = OpAccessChain %15 %103 %113 + OpStore %117 %116 + OpBranch %98 + %98 = OpLabel + %119 = OpISub %14 %156 %57 + OpStore %93 %119 + OpBranch %95 + %97 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %120 = OpVariable %15 Function + %124 = OpVariable %15 Function + %133 = OpVariable %40 Function + %121 = OpAccessChain %23 %20 %22 + %122 = OpLoad %17 %121 + %123 = OpConvertFToS %14 %122 + OpStore %120 %123 + OpStore %124 %123 + OpBranch %126 + %126 = OpLabel + %157 = OpPhi %14 %123 %13 %149 %129 + OpLoopMerge %128 %129 None + OpBranch %130 + %130 = OpLabel + %132 = OpSGreaterThanEqual %36 %157 %101 + OpBranchConditional %132 %127 %128 + %127 = OpLabel + %137 = OpIAdd %14 %157 %123 + %138 = OpAccessChain %15 %133 %137 + %139 = OpLoad %14 %138 + %140 = OpAccessChain %15 %133 %157 + OpStore %140 %139 + %143 = OpIAdd %14 %157 %123 + %145 = OpAccessChain %15 %133 %157 + %146 = OpLoad %14 %145 + %147 = OpAccessChain %15 %133 %143 + OpStore %147 %146 + OpBranch %129 + %129 = OpLabel + %149 = OpISub %14 %157 %57 + OpStore %124 %149 + OpBranch %126 + %128 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 47 -> 48 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(47) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 54 -> 55 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(54) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 66)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 78 -> 79 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(78) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 85 -> 86 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(85) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function c + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 96)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 109 -> 110 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(109) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent and supported. + EXPECT_TRUE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 116 -> 117 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(116) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function d + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 127)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 139 -> 140 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(139) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 146 -> 147 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(146) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 in_vec; +// Loop iterates from symbolic to symbolic +void a() { + int M = int(in_vec.x); + int N = int(in_vec.y); + int arr[10]; + for (int i = M; i < N; i++) { // Bounds are N - M - 1 + arr[i+M+N] = arr[i+M+2*N]; // |distance| = N + arr[i+M+2*N] = arr[i+M+N]; // |distance| = N + } +} +void b() { + int M = int(in_vec.x); + int N = int(in_vec.y); + int arr[10]; + for (int i = M; i <= N; i++) { // Bounds are N - M + arr[i+M+N] = arr[i+M+2*N]; // |distance| = N + arr[i+M+2*N] = arr[i+M+N]; // |distance| = N + } +} +void c() { + int M = int(in_vec.x); + int N = int(in_vec.y); + int arr[10]; + for (int i = M; i > N; i--) { // Bounds are M - N - 1 + arr[i+M+N] = arr[i+M+2*N]; // |distance| = N + arr[i+M+2*N] = arr[i+M+N]; // |distance| = N + } +} +void d() { + int M = int(in_vec.x); + int N = int(in_vec.y); + int arr[10]; + for (int i = M; i >= N; i--) { // Bounds are M - N + arr[i+M+N] = arr[i+M+2*N]; // |distance| = N + arr[i+M+2*N] = arr[i+M+N]; // |distance| = N + } +} +void main(){ + a(); + b(); + c(); + d(); +} +*/ +TEST(DependencyAnalysisHelpers, symbolic_to_symbolic) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %20 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %16 "M" + OpName %20 "in_vec" + OpName %27 "N" + OpName %32 "i" + OpName %46 "arr" + OpName %79 "M" + OpName %83 "N" + OpName %87 "i" + OpName %97 "arr" + OpName %128 "M" + OpName %132 "N" + OpName %136 "i" + OpName %146 "arr" + OpName %177 "M" + OpName %181 "N" + OpName %185 "i" + OpName %195 "arr" + OpDecorate %20 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpTypeFloat 32 + %18 = OpTypeVector %17 4 + %19 = OpTypePointer Input %18 + %20 = OpVariable %19 Input + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 0 + %23 = OpTypePointer Input %17 + %28 = OpConstant %21 1 + %41 = OpTypeBool + %43 = OpConstant %21 10 + %44 = OpTypeArray %14 %43 + %45 = OpTypePointer Function %44 + %55 = OpConstant %14 2 + %77 = OpConstant %14 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %226 = OpFunctionCall %2 %6 + %227 = OpFunctionCall %2 %8 + %228 = OpFunctionCall %2 %10 + %229 = OpFunctionCall %2 %12 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %16 = OpVariable %15 Function + %27 = OpVariable %15 Function + %32 = OpVariable %15 Function + %46 = OpVariable %45 Function + %24 = OpAccessChain %23 %20 %22 + %25 = OpLoad %17 %24 + %26 = OpConvertFToS %14 %25 + OpStore %16 %26 + %29 = OpAccessChain %23 %20 %28 + %30 = OpLoad %17 %29 + %31 = OpConvertFToS %14 %30 + OpStore %27 %31 + OpStore %32 %26 + OpBranch %34 + %34 = OpLabel + %230 = OpPhi %14 %26 %7 %78 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %42 = OpSLessThan %41 %230 %31 + OpBranchConditional %42 %35 %36 + %35 = OpLabel + %49 = OpIAdd %14 %230 %26 + %51 = OpIAdd %14 %49 %31 + %54 = OpIAdd %14 %230 %26 + %57 = OpIMul %14 %55 %31 + %58 = OpIAdd %14 %54 %57 + %59 = OpAccessChain %15 %46 %58 + %60 = OpLoad %14 %59 + %61 = OpAccessChain %15 %46 %51 + OpStore %61 %60 + %64 = OpIAdd %14 %230 %26 + %66 = OpIMul %14 %55 %31 + %67 = OpIAdd %14 %64 %66 + %70 = OpIAdd %14 %230 %26 + %72 = OpIAdd %14 %70 %31 + %73 = OpAccessChain %15 %46 %72 + %74 = OpLoad %14 %73 + %75 = OpAccessChain %15 %46 %67 + OpStore %75 %74 + OpBranch %37 + %37 = OpLabel + %78 = OpIAdd %14 %230 %77 + OpStore %32 %78 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %79 = OpVariable %15 Function + %83 = OpVariable %15 Function + %87 = OpVariable %15 Function + %97 = OpVariable %45 Function + %80 = OpAccessChain %23 %20 %22 + %81 = OpLoad %17 %80 + %82 = OpConvertFToS %14 %81 + OpStore %79 %82 + %84 = OpAccessChain %23 %20 %28 + %85 = OpLoad %17 %84 + %86 = OpConvertFToS %14 %85 + OpStore %83 %86 + OpStore %87 %82 + OpBranch %89 + %89 = OpLabel + %231 = OpPhi %14 %82 %9 %127 %92 + OpLoopMerge %91 %92 None + OpBranch %93 + %93 = OpLabel + %96 = OpSLessThanEqual %41 %231 %86 + OpBranchConditional %96 %90 %91 + %90 = OpLabel + %100 = OpIAdd %14 %231 %82 + %102 = OpIAdd %14 %100 %86 + %105 = OpIAdd %14 %231 %82 + %107 = OpIMul %14 %55 %86 + %108 = OpIAdd %14 %105 %107 + %109 = OpAccessChain %15 %97 %108 + %110 = OpLoad %14 %109 + %111 = OpAccessChain %15 %97 %102 + OpStore %111 %110 + %114 = OpIAdd %14 %231 %82 + %116 = OpIMul %14 %55 %86 + %117 = OpIAdd %14 %114 %116 + %120 = OpIAdd %14 %231 %82 + %122 = OpIAdd %14 %120 %86 + %123 = OpAccessChain %15 %97 %122 + %124 = OpLoad %14 %123 + %125 = OpAccessChain %15 %97 %117 + OpStore %125 %124 + OpBranch %92 + %92 = OpLabel + %127 = OpIAdd %14 %231 %77 + OpStore %87 %127 + OpBranch %89 + %91 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %128 = OpVariable %15 Function + %132 = OpVariable %15 Function + %136 = OpVariable %15 Function + %146 = OpVariable %45 Function + %129 = OpAccessChain %23 %20 %22 + %130 = OpLoad %17 %129 + %131 = OpConvertFToS %14 %130 + OpStore %128 %131 + %133 = OpAccessChain %23 %20 %28 + %134 = OpLoad %17 %133 + %135 = OpConvertFToS %14 %134 + OpStore %132 %135 + OpStore %136 %131 + OpBranch %138 + %138 = OpLabel + %232 = OpPhi %14 %131 %11 %176 %141 + OpLoopMerge %140 %141 None + OpBranch %142 + %142 = OpLabel + %145 = OpSGreaterThan %41 %232 %135 + OpBranchConditional %145 %139 %140 + %139 = OpLabel + %149 = OpIAdd %14 %232 %131 + %151 = OpIAdd %14 %149 %135 + %154 = OpIAdd %14 %232 %131 + %156 = OpIMul %14 %55 %135 + %157 = OpIAdd %14 %154 %156 + %158 = OpAccessChain %15 %146 %157 + %159 = OpLoad %14 %158 + %160 = OpAccessChain %15 %146 %151 + OpStore %160 %159 + %163 = OpIAdd %14 %232 %131 + %165 = OpIMul %14 %55 %135 + %166 = OpIAdd %14 %163 %165 + %169 = OpIAdd %14 %232 %131 + %171 = OpIAdd %14 %169 %135 + %172 = OpAccessChain %15 %146 %171 + %173 = OpLoad %14 %172 + %174 = OpAccessChain %15 %146 %166 + OpStore %174 %173 + OpBranch %141 + %141 = OpLabel + %176 = OpISub %14 %232 %77 + OpStore %136 %176 + OpBranch %138 + %140 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %177 = OpVariable %15 Function + %181 = OpVariable %15 Function + %185 = OpVariable %15 Function + %195 = OpVariable %45 Function + %178 = OpAccessChain %23 %20 %22 + %179 = OpLoad %17 %178 + %180 = OpConvertFToS %14 %179 + OpStore %177 %180 + %182 = OpAccessChain %23 %20 %28 + %183 = OpLoad %17 %182 + %184 = OpConvertFToS %14 %183 + OpStore %181 %184 + OpStore %185 %180 + OpBranch %187 + %187 = OpLabel + %233 = OpPhi %14 %180 %13 %225 %190 + OpLoopMerge %189 %190 None + OpBranch %191 + %191 = OpLabel + %194 = OpSGreaterThanEqual %41 %233 %184 + OpBranchConditional %194 %188 %189 + %188 = OpLabel + %198 = OpIAdd %14 %233 %180 + %200 = OpIAdd %14 %198 %184 + %203 = OpIAdd %14 %233 %180 + %205 = OpIMul %14 %55 %184 + %206 = OpIAdd %14 %203 %205 + %207 = OpAccessChain %15 %195 %206 + %208 = OpLoad %14 %207 + %209 = OpAccessChain %15 %195 %200 + OpStore %209 %208 + %212 = OpIAdd %14 %233 %180 + %214 = OpIMul %14 %55 %184 + %215 = OpIAdd %14 %212 %214 + %218 = OpIAdd %14 %233 %180 + %220 = OpIAdd %14 %218 %184 + %221 = OpAccessChain %15 %195 %220 + %222 = OpLoad %14 %221 + %223 = OpAccessChain %15 %195 %215 + OpStore %223 %222 + OpBranch %190 + %190 = OpLabel + %225 = OpISub %14 %233 %77 + OpStore %185 %225 + OpBranch %187 + %189 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 35)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 60 -> 61 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(60) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 74 -> 75 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(74) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 90)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 110 -> 111 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(110) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 124 -> 125 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(124) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function c + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 139)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 159 -> 160 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(159) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 173 -> 174 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(173) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function d + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 188)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(stores[i]); + } + + // 208 -> 209 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(208) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 222 -> 223 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(222) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_compatibility.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_compatibility.cpp new file mode 100644 index 000000000..cda8576c5 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_compatibility.cpp @@ -0,0 +1,1785 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using FusionCompatibilityTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int i = 0; // Can't fuse, i=0 in first & i=10 in second + for (; i < 10; i++) {} + for (; i < 10; i++) {} +} +*/ +TEST_F(FusionCompatibilityTest, SameInductionVariableDifferentBounds) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %31 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %31 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %31 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpBranch %22 + %22 = OpLabel + %32 = OpPhi %6 %31 %12 %30 %25 + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %28 = OpSLessThan %17 %32 %16 + OpBranchConditional %28 %23 %24 + %23 = OpLabel + OpBranch %25 + %25 = OpLabel + %30 = OpIAdd %6 %32 %20 + OpStore %8 %30 + OpBranch %22 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 1 +#version 440 core +void main() { + for (int i = 0; i < 10; i++) {} + for (int i = 0; i < 10; i++) {} +} +*/ +TEST_F(FusionCompatibilityTest, Compatible) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %32 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %33 = OpPhi %6 %9 %12 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %33 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %33 %20 + OpStore %22 %31 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 2 +#version 440 core +void main() { + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, DifferentName) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %32 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %33 = OpPhi %6 %9 %12 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %33 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %33 %20 + OpStore %22 %31 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + // Can't fuse, different step + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j=j+2) {} +} + +*/ +TEST_F(FusionCompatibilityTest, SameBoundsDifferentStep) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %31 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %34 = OpPhi %6 %9 %12 %32 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %34 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %32 = OpIAdd %6 %34 %31 + OpStore %22 %32 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 4 +#version 440 core +void main() { + // Can't fuse, different upper bound + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 20; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, DifferentUpperBound) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %29 = OpConstant %6 20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %34 = OpPhi %6 %9 %12 %32 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %30 = OpSLessThan %17 %34 %29 + OpBranchConditional %30 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %32 = OpIAdd %6 %34 %20 + OpStore %22 %32 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 5 +#version 440 core +void main() { + // Can't fuse, different lower bound + for (int i = 5; i < 10; i++) {} + for (int j = 0; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, DifferentLowerBound) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 5 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %23 = OpConstant %6 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %23 + OpBranch %24 + %24 = OpLabel + %34 = OpPhi %6 %23 %12 %32 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %30 = OpSLessThan %17 %34 %16 + OpBranchConditional %30 %25 %26 + %25 = OpLabel + OpBranch %27 + %27 = OpLabel + %32 = OpIAdd %6 %34 %20 + OpStore %22 %32 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 6 +#version 440 core +void main() { + // Can't fuse, break in first loop + for (int i = 0; i < 10; i++) { + if (i == 5) { + break; + } + } + for (int j = 0; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, Break) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %28 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 5 + %26 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %28 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %38 = OpPhi %6 %9 %5 %27 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %38 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpIEqual %17 %38 %20 + OpSelectionMerge %23 None + OpBranchConditional %21 %22 %23 + %22 = OpLabel + OpBranch %12 + %23 = OpLabel + OpBranch %13 + %13 = OpLabel + %27 = OpIAdd %6 %38 %26 + OpStore %8 %27 + OpBranch %10 + %12 = OpLabel + OpStore %28 %9 + OpBranch %29 + %29 = OpLabel + %39 = OpPhi %6 %9 %12 %37 %32 + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %35 = OpSLessThan %17 %39 %16 + OpBranchConditional %35 %30 %31 + %30 = OpLabel + OpBranch %32 + %32 = OpLabel + %37 = OpIAdd %6 %39 %26 + OpStore %28 %37 + OpBranch %29 + %31 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0) in vec4 c; +void main() { + int N = int(c.x); + for (int i = 0; i < N; i++) {} + for (int j = 0; j < N; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, UnknownButSameUpperBound) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "j" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %33 = OpVariable %7 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + %44 = OpPhi %6 %20 %5 %32 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %44 %18 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpBranch %24 + %24 = OpLabel + %32 = OpIAdd %6 %44 %31 + OpStore %19 %32 + OpBranch %21 + %23 = OpLabel + OpStore %33 %20 + OpBranch %34 + %34 = OpLabel + %46 = OpPhi %6 %20 %23 %43 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpSLessThan %28 %46 %18 + OpBranchConditional %41 %35 %36 + %35 = OpLabel + OpBranch %37 + %37 = OpLabel + %43 = OpIAdd %6 %46 %31 + OpStore %33 %43 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0) in vec4 c; +void main() { + int N = int(c.x); + for (int i = 0; N > j; i++) {} + for (int j = 0; N > j; j++) {} +} +*/ +TEST_F(FusionCompatibilityTest, UnknownButSameUpperBoundReverseCondition) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "j" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %33 = OpVariable %7 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + %45 = OpPhi %6 %20 %5 %32 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSGreaterThan %28 %18 %45 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpBranch %24 + %24 = OpLabel + %32 = OpIAdd %6 %45 %31 + OpStore %19 %32 + OpBranch %21 + %23 = OpLabel + OpStore %33 %20 + OpBranch %34 + %34 = OpLabel + %47 = OpPhi %6 %20 %23 %43 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpSGreaterThan %28 %18 %47 + OpBranchConditional %41 %35 %36 + %35 = OpLabel + OpBranch %37 + %37 = OpLabel + %43 = OpIAdd %6 %47 %31 + OpStore %33 %43 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0) in vec4 c; +void main() { + // Can't fuse different bound + int N = int(c.x); + for (int i = 0; i < N; i++) {} + for (int j = 0; j < N+1; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, UnknownUpperBoundAddition) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "j" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %33 = OpVariable %7 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + %45 = OpPhi %6 %20 %5 %32 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %45 %18 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpBranch %24 + %24 = OpLabel + %32 = OpIAdd %6 %45 %31 + OpStore %19 %32 + OpBranch %21 + %23 = OpLabel + OpStore %33 %20 + OpBranch %34 + %34 = OpLabel + %47 = OpPhi %6 %20 %23 %44 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpIAdd %6 %18 %31 + %42 = OpSLessThan %28 %47 %41 + OpBranchConditional %42 %35 %36 + %35 = OpLabel + OpBranch %37 + %37 = OpLabel + %44 = OpIAdd %6 %47 %31 + OpStore %33 %44 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 10 +#version 440 core +void main() { + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j++) {} + for (int k = 0; k < 10; k++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, SeveralAdjacentLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + OpName %32 "k" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + %32 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %42 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %42 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %42 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %43 = OpPhi %6 %9 %12 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %43 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %43 %20 + OpStore %22 %31 + OpBranch %23 + %25 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + %44 = OpPhi %6 %9 %25 %41 %36 + OpLoopMerge %35 %36 None + OpBranch %37 + %37 = OpLabel + %39 = OpSLessThan %17 %44 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpBranch %36 + %36 = OpLabel + %41 = OpIAdd %6 %44 %20 + OpStore %32 %41 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_0).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_2).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_1, loop_0).AreCompatible()); + EXPECT_TRUE(LoopFusion(context.get(), loop_0, loop_1).AreCompatible()); + EXPECT_TRUE(LoopFusion(context.get(), loop_1, loop_2).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + // Can't fuse, not adjacent + int x = 0; + for (int i = 0; i < 10; i++) { + if (i > 10) { + x++; + } + } + x++; + for (int j = 0; j < 10; j++) {} + for (int k = 0; k < 10; k++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, NonAdjacentLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "x" + OpName %10 "i" + OpName %31 "j" + OpName %41 "k" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %25 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %31 = OpVariable %7 Function + %41 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %52 = OpPhi %6 %9 %5 %56 %14 + %51 = OpPhi %6 %9 %5 %28 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %51 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %21 = OpSGreaterThan %18 %52 %17 + OpSelectionMerge %23 None + OpBranchConditional %21 %22 %23 + %22 = OpLabel + %26 = OpIAdd %6 %52 %25 + OpStore %8 %26 + OpBranch %23 + %23 = OpLabel + %56 = OpPhi %6 %52 %12 %26 %22 + OpBranch %14 + %14 = OpLabel + %28 = OpIAdd %6 %51 %25 + OpStore %10 %28 + OpBranch %11 + %13 = OpLabel + %30 = OpIAdd %6 %52 %25 + OpStore %8 %30 + OpStore %31 %9 + OpBranch %32 + %32 = OpLabel + %53 = OpPhi %6 %9 %13 %40 %35 + OpLoopMerge %34 %35 None + OpBranch %36 + %36 = OpLabel + %38 = OpSLessThan %18 %53 %17 + OpBranchConditional %38 %33 %34 + %33 = OpLabel + OpBranch %35 + %35 = OpLabel + %40 = OpIAdd %6 %53 %25 + OpStore %31 %40 + OpBranch %32 + %34 = OpLabel + OpStore %41 %9 + OpBranch %42 + %42 = OpLabel + %54 = OpPhi %6 %9 %34 %50 %45 + OpLoopMerge %44 %45 None + OpBranch %46 + %46 = OpLabel + %48 = OpSLessThan %18 %54 %17 + OpBranchConditional %48 %43 %44 + %43 = OpLabel + OpBranch %45 + %45 = OpLabel + %50 = OpIAdd %6 %54 %25 + OpStore %41 %50 + OpBranch %42 + %44 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_0).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_2).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_1).AreCompatible()); + EXPECT_TRUE(LoopFusion(context.get(), loop_1, loop_2).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 12 +#version 440 core +void main() { + int j = 0; + int i = 0; + for (; i < 10; i++) {} + for (; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, CompatibleInitDeclaredBeforeLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %10 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %21 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %32 = OpPhi %6 %9 %5 %22 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %32 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + OpBranch %14 + %14 = OpLabel + %22 = OpIAdd %6 %32 %21 + OpStore %10 %22 + OpBranch %11 + %13 = OpLabel + OpBranch %23 + %23 = OpLabel + %33 = OpPhi %6 %9 %13 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %18 %33 %17 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %33 %21 + OpStore %8 %31 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 13 regenerate! +#version 440 core +void main() { + int[10] a; + int[10] b; + // Can't fuse, several induction variables + for (int j = 0; j < 10; j++) { + b[i] = a[i]; + } + for (int i = 0, j = 0; i < 10; i++, j = j+2) { + } +} + +*/ +TEST_F(FusionCompatibilityTest, SeveralInductionVariables) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %23 "b" + OpName %25 "a" + OpName %33 "i" + OpName %34 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %31 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %33 = OpVariable %7 Function + %34 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %50 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %50 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %50 + %28 = OpLoad %6 %27 + %29 = OpAccessChain %7 %23 %50 + OpStore %29 %28 + OpBranch %13 + %13 = OpLabel + %32 = OpIAdd %6 %50 %31 + OpStore %8 %32 + OpBranch %10 + %12 = OpLabel + OpStore %33 %9 + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %49 %38 + %51 = OpPhi %6 %9 %12 %46 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %51 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %25 %52 + OpStore %44 %51 + OpBranch %38 + %38 = OpLabel + %46 = OpIAdd %6 %51 %31 + OpStore %33 %46 + %49 = OpIAdd %6 %52 %48 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 14 +#version 440 core +void main() { + // Fine + for (int i = 0; i < 10; i = i + 2) {} + for (int j = 0; j < 10; j = j + 2) {} +} + +*/ +TEST_F(FusionCompatibilityTest, CompatibleNonIncrementStep) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %10 "i" + OpName %11 "i" + OpName %24 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %22 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %24 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %34 = OpPhi %6 %9 %5 %23 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %34 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + OpBranch %15 + %15 = OpLabel + %23 = OpIAdd %6 %34 %22 + OpStore %11 %23 + OpBranch %12 + %14 = OpLabel + OpStore %24 %9 + OpBranch %25 + %25 = OpLabel + %35 = OpPhi %6 %9 %14 %33 %28 + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %31 = OpSLessThan %19 %35 %18 + OpBranchConditional %31 %26 %27 + %26 = OpLabel + OpBranch %28 + %28 = OpLabel + %33 = OpIAdd %6 %35 %22 + OpStore %24 %33 + OpBranch %25 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 15 +#version 440 core + +int j = 0; + +void main() { + // Not compatible, unknown init for second. + for (int i = 0; i < 10; i = i + 2) {} + for (; j < 10; j = j + 2) {} +} + +*/ +TEST_F(FusionCompatibilityTest, UnknonInitForSecondLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %11 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Private %6 + %8 = OpVariable %7 Private + %9 = OpConstant %6 0 + %10 = OpTypePointer Function %6 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %22 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %11 = OpVariable %10 Function + OpStore %8 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %33 = OpPhi %6 %9 %5 %23 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %33 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + OpBranch %15 + %15 = OpLabel + %23 = OpIAdd %6 %33 %22 + OpStore %11 %23 + OpBranch %12 + %14 = OpLabel + OpBranch %24 + %24 = OpLabel + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %29 = OpLoad %6 %8 + %30 = OpSLessThan %19 %29 %18 + OpBranchConditional %30 %25 %26 + %25 = OpLabel + OpBranch %27 + %27 = OpLabel + %31 = OpLoad %6 %8 + %32 = OpIAdd %6 %31 %22 + OpStore %8 %32 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 16 +#version 440 core +void main() { + // Not compatible, continue in loop 0 + for (int i = 0; i < 10; ++i) { + if (i % 2 == 1) { + continue; + } + } + for (int j = 0; j < 10; ++j) {} +} + +*/ +TEST_F(FusionCompatibilityTest, Continue) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %29 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 2 + %22 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %29 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %39 = OpPhi %6 %9 %5 %28 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %39 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpSMod %6 %39 %20 + %23 = OpIEqual %17 %21 %22 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %25 + %24 = OpLabel + OpBranch %13 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %28 = OpIAdd %6 %39 %22 + OpStore %8 %28 + OpBranch %10 + %12 = OpLabel + OpStore %29 %9 + OpBranch %30 + %30 = OpLabel + %40 = OpPhi %6 %9 %12 %38 %33 + OpLoopMerge %32 %33 None + OpBranch %34 + %34 = OpLabel + %36 = OpSLessThan %17 %40 %16 + OpBranchConditional %36 %31 %32 + %31 = OpLabel + OpBranch %33 + %33 = OpLabel + %38 = OpIAdd %6 %40 %22 + OpStore %29 %38 + OpBranch %30 + %32 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + // Compatible + for (int i = 0; i < 10; ++i) { + if (i % 2 == 1) { + } else { + a[i] = i; + } + } + for (int j = 0; j < 10; ++j) {} +} + +*/ +TEST_F(FusionCompatibilityTest, IfElseInLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %31 "a" + OpName %37 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 2 + %22 = OpConstant %6 1 + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypePointer Function %29 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %31 = OpVariable %30 Function + %37 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %47 = OpPhi %6 %9 %5 %36 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %47 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpSMod %6 %47 %20 + %23 = OpIEqual %17 %21 %22 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %26 + %24 = OpLabel + OpBranch %25 + %26 = OpLabel + %34 = OpAccessChain %7 %31 %47 + OpStore %34 %47 + OpBranch %25 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %36 = OpIAdd %6 %47 %22 + OpStore %8 %36 + OpBranch %10 + %12 = OpLabel + OpStore %37 %9 + OpBranch %38 + %38 = OpLabel + %48 = OpPhi %6 %9 %12 %46 %41 + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel + %44 = OpSLessThan %17 %48 %16 + OpBranchConditional %44 %39 %40 + %39 = OpLabel + OpBranch %41 + %41 = OpLabel + %46 = OpIAdd %6 %48 %22 + OpStore %37 %46 + OpBranch %38 + %40 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_illegal.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_illegal.cpp new file mode 100644 index 000000000..26d54457d --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_illegal.cpp @@ -0,0 +1,1592 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using FusionIllegalTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, loop-independent dependence will become a + // backward loop-carried antidependence + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 2; + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedRAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %53 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %53 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %53 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %54 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpIAdd %6 %54 %29 + %46 = OpAccessChain %7 %23 %45 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %42 %54 + OpStore %50 %49 + OpBranch %38 + %38 = OpLabel + %52 = OpIAdd %6 %54 %29 + OpStore %34 %52 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core + +int func() { + return 10; +} + +void main() { + int[10] a; + int[10] b; + // Illegal, function call + for (int i = 0; i < 10; i++) { + a[i] = func(); + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]; + } +} +*/ +TEST_F(FusionIllegalTest, FunctionCall) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "func(" + OpName %14 "i" + OpName %28 "a" + OpName %35 "i" + OpName %43 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeFunction %6 + %10 = OpConstant %6 10 + %13 = OpTypePointer Function %6 + %15 = OpConstant %6 0 + %22 = OpTypeBool + %24 = OpTypeInt 32 0 + %25 = OpConstant %24 10 + %26 = OpTypeArray %6 %25 + %27 = OpTypePointer Function %26 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %14 = OpVariable %13 Function + %28 = OpVariable %27 Function + %35 = OpVariable %13 Function + %43 = OpVariable %27 Function + OpStore %14 %15 + OpBranch %16 + %16 = OpLabel + %51 = OpPhi %6 %15 %5 %34 %19 + OpLoopMerge %18 %19 None + OpBranch %20 + %20 = OpLabel + %23 = OpSLessThan %22 %51 %10 + OpBranchConditional %23 %17 %18 + %17 = OpLabel + %30 = OpFunctionCall %6 %8 + %31 = OpAccessChain %13 %28 %51 + OpStore %31 %30 + OpBranch %19 + %19 = OpLabel + %34 = OpIAdd %6 %51 %33 + OpStore %14 %34 + OpBranch %16 + %18 = OpLabel + OpStore %35 %15 + OpBranch %36 + %36 = OpLabel + %52 = OpPhi %6 %15 %18 %50 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %22 %52 %10 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpAccessChain %13 %28 %52 + %47 = OpLoad %6 %46 + %48 = OpAccessChain %13 %43 %52 + OpStore %48 %47 + OpBranch %39 + %39 = OpLabel + %50 = OpIAdd %6 %52 %33 + OpStore %35 %50 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + OpReturnValue %10 + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 16 +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Illegal outer. + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i+1][j] + 10; + } + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedRAWOuterLoop) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %48 "i" + OpName %56 "j" + OpName %64 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %48 = OpVariable %7 Function + %56 = OpVariable %7 Function + %64 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %78 = OpPhi %6 %9 %5 %47 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %78 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %82 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %82 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %78 %82 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %78 %82 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %82 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %47 = OpIAdd %6 %78 %44 + OpStore %8 %47 + OpBranch %10 + %12 = OpLabel + OpStore %48 %9 + OpBranch %49 + %49 = OpLabel + %79 = OpPhi %6 %9 %12 %77 %52 + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %55 = OpSLessThan %17 %79 %16 + OpBranchConditional %55 %50 %51 + %50 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + %80 = OpPhi %6 %9 %50 %75 %60 + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %63 = OpSLessThan %17 %80 %16 + OpBranchConditional %63 %58 %59 + %58 = OpLabel + %68 = OpIAdd %6 %79 %44 + %70 = OpAccessChain %7 %32 %68 %80 + %71 = OpLoad %6 %70 + %72 = OpIAdd %6 %71 %16 + %73 = OpAccessChain %7 %64 %79 %80 + OpStore %73 %72 + OpBranch %60 + %60 = OpLabel + %75 = OpIAdd %6 %80 %44 + OpStore %56 %75 + OpBranch %57 + %59 = OpLabel + OpBranch %52 + %52 = OpLabel + %77 = OpIAdd %6 %79 %44 + OpStore %48 %77 + OpBranch %49 + %51 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 19 +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, would create a backward loop-carried anti-dependence. + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i+1] = c[i] + 2; + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedWAR) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "c" + OpName %25 "a" + OpName %34 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %52 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %52 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %52 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %52 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %52 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %53 = OpPhi %6 %9 %12 %51 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %53 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %43 = OpIAdd %6 %53 %29 + %45 = OpAccessChain %7 %23 %53 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %25 %43 + OpStore %49 %48 + OpBranch %38 + %38 = OpLabel + %51 = OpIAdd %6 %53 %29 + OpStore %34 %51 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 21 +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, would create a backward loop-carried anti-dependence. + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i+1] = c[i+1] + 2; + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedWAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %44 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %49 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %44 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %54 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %54 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %54 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %54 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %55 = OpPhi %6 %9 %12 %53 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %55 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %43 = OpIAdd %6 %55 %29 + %46 = OpIAdd %6 %55 %29 + %47 = OpAccessChain %7 %44 %46 + %48 = OpLoad %6 %47 + %50 = OpIAdd %6 %48 %49 + %51 = OpAccessChain %7 %23 %43 + OpStore %51 %50 + OpBranch %38 + %38 = OpLabel + %53 = OpIAdd %6 %55 %29 + OpStore %34 %53 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 28 +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + + // Illegal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_0 += b[j]; + } +} + +*/ +TEST_F(FusionIllegalTest, SameReductionVariable) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "i" + OpName %24 "a" + OpName %33 "j" + OpName %41 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %33 = OpVariable %7 Function + %41 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %52 = OpPhi %6 %9 %5 %29 %14 + %49 = OpPhi %6 %9 %5 %32 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %49 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %49 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %52 %27 + OpStore %8 %29 + OpBranch %14 + %14 = OpLabel + %32 = OpIAdd %6 %49 %31 + OpStore %10 %32 + OpBranch %11 + %13 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %51 = OpPhi %6 %52 %13 %46 %37 + %50 = OpPhi %6 %9 %13 %48 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %18 %50 %17 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + %43 = OpAccessChain %7 %41 %50 + %44 = OpLoad %6 %43 + %46 = OpIAdd %6 %51 %44 + OpStore %8 %46 + OpBranch %37 + %37 = OpLabel + %48 = OpIAdd %6 %50 %31 + OpStore %33 %48 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 28 +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + + // Illegal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_0 += b[j]; + } +} + +*/ +TEST_F(FusionIllegalTest, SameReductionVariableLCSSA) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "i" + OpName %24 "a" + OpName %33 "j" + OpName %41 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %33 = OpVariable %7 Function + %41 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %52 = OpPhi %6 %9 %5 %29 %14 + %49 = OpPhi %6 %9 %5 %32 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %49 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %49 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %52 %27 + OpStore %8 %29 + OpBranch %14 + %14 = OpLabel + %32 = OpIAdd %6 %49 %31 + OpStore %10 %32 + OpBranch %11 + %13 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %51 = OpPhi %6 %52 %13 %46 %37 + %50 = OpPhi %6 %9 %13 %48 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %18 %50 %17 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + %43 = OpAccessChain %7 %41 %50 + %44 = OpLoad %6 %43 + %46 = OpIAdd %6 %51 %44 + OpStore %8 %46 + OpBranch %37 + %37 = OpLabel + %48 = OpIAdd %6 %50 %31 + OpStore %33 %48 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopUtils utils_0(context.get(), loops[0]); + utils_0.MakeLoopClosedSSA(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 30 +#version 440 core +int x; +void main() { + int[10] a; + int[10] b; + + // Illegal, x is unknown. + for (int i = 0; i < 10; i++) { + a[x] = a[i]; + } + for (int j = 0; j < 10; j++) { + a[j] = b[j]; + } +} + +*/ +TEST_F(FusionIllegalTest, UnknownIndexVariable) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "x" + OpName %34 "j" + OpName %43 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %24 = OpTypePointer Private %6 + %25 = OpVariable %24 Private + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %50 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %50 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpLoad %6 %25 + %28 = OpAccessChain %7 %23 %50 + %29 = OpLoad %6 %28 + %30 = OpAccessChain %7 %23 %26 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %50 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %51 = OpPhi %6 %9 %12 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %51 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %43 %51 + %46 = OpLoad %6 %45 + %47 = OpAccessChain %7 %23 %51 + OpStore %47 %46 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %51 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum = 0; + + // Illegal, accumulator used for indexing. + for (int i = 0; i < 10; i++) { + sum += a[i]; + b[sum] = a[i]; + } + for (int j = 0; j < 10; j++) { + b[j] = b[j]+1; + } +} + +*/ +TEST_F(FusionIllegalTest, AccumulatorIndexing) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum" + OpName %10 "i" + OpName %24 "a" + OpName %30 "b" + OpName %39 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %37 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %30 = OpVariable %23 Function + %39 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %57 = OpPhi %6 %9 %5 %29 %14 + %55 = OpPhi %6 %9 %5 %38 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %55 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %55 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %57 %27 + OpStore %8 %29 + %33 = OpAccessChain %7 %24 %55 + %34 = OpLoad %6 %33 + %35 = OpAccessChain %7 %30 %29 + OpStore %35 %34 + OpBranch %14 + %14 = OpLabel + %38 = OpIAdd %6 %55 %37 + OpStore %10 %38 + OpBranch %11 + %13 = OpLabel + OpStore %39 %9 + OpBranch %40 + %40 = OpLabel + %56 = OpPhi %6 %9 %13 %54 %43 + OpLoopMerge %42 %43 None + OpBranch %44 + %44 = OpLabel + %46 = OpSLessThan %18 %56 %17 + OpBranchConditional %46 %41 %42 + %41 = OpLabel + %49 = OpAccessChain %7 %30 %56 + %50 = OpLoad %6 %49 + %51 = OpIAdd %6 %50 %37 + %52 = OpAccessChain %7 %30 %56 + OpStore %52 %51 + OpBranch %43 + %43 = OpLabel + %54 = OpIAdd %6 %56 %37 + OpStore %39 %54 + OpBranch %40 + %42 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 33 +#version 440 core +void main() { + int[10] a; + int[10] b; + + // Illegal, barrier. + for (int i = 0; i < 10; i++) { + a[i] = a[i] * 2; + memoryBarrier(); + } + for (int j = 0; j < 10; j++) { + b[j] = b[j] + 1; + } +} + +*/ +TEST_F(FusionIllegalTest, Barrier) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %36 "j" + OpName %44 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %31 = OpConstant %19 1 + %32 = OpConstant %19 3400 + %34 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %36 = OpVariable %7 Function + %44 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %35 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %53 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %53 + OpStore %30 %29 + OpMemoryBarrier %31 %32 + OpBranch %13 + %13 = OpLabel + %35 = OpIAdd %6 %53 %34 + OpStore %8 %35 + OpBranch %10 + %12 = OpLabel + OpStore %36 %9 + OpBranch %37 + %37 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %40 + OpLoopMerge %39 %40 None + OpBranch %41 + %41 = OpLabel + %43 = OpSLessThan %17 %54 %16 + OpBranchConditional %43 %38 %39 + %38 = OpLabel + %47 = OpAccessChain %7 %44 %54 + %48 = OpLoad %6 %47 + %49 = OpIAdd %6 %48 %34 + %50 = OpAccessChain %7 %44 %54 + OpStore %50 %49 + OpBranch %40 + %40 = OpLabel + %52 = OpIAdd %6 %54 %34 + OpStore %36 %52 + OpBranch %37 + %39 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +struct TestStruct { + int[10] a; + int b; +}; + +void main() { + TestStruct test_0; + TestStruct test_1; + + for (int i = 0; i < 10; i++) { + test_0.a[i] = i; + } + for (int j = 0; j < 10; j++) { + test_0 = test_1; + } +} + +*/ +TEST_F(FusionIllegalTest, ArrayInStruct) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "TestStruct" + OpMemberName %22 0 "a" + OpMemberName %22 1 "b" + OpName %24 "test_0" + OpName %31 "j" + OpName %39 "test_1" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypeStruct %21 %6 + %23 = OpTypePointer Function %22 + %29 = OpConstant %6 1 + %47 = OpUndef %22 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %24 = OpVariable %23 Function + %31 = OpVariable %7 Function + %39 = OpVariable %23 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %43 = OpPhi %6 %9 %5 %30 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %43 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %24 %9 %43 + OpStore %27 %43 + OpBranch %13 + %13 = OpLabel + %30 = OpIAdd %6 %43 %29 + OpStore %8 %30 + OpBranch %10 + %12 = OpLabel + OpStore %31 %9 + OpBranch %32 + %32 = OpLabel + %44 = OpPhi %6 %9 %12 %42 %35 + OpLoopMerge %34 %35 None + OpBranch %36 + %36 = OpLabel + %38 = OpSLessThan %17 %44 %16 + OpBranchConditional %38 %33 %34 + %33 = OpLabel + OpStore %24 %47 + OpBranch %35 + %35 = OpLabel + %42 = OpIAdd %6 %44 %29 + OpStore %31 %42 + OpBranch %32 + %34 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 450 + +struct P {float x,y,z;}; +uniform G { int a; P b[2]; int c; } g; +layout(location = 0) out float o; + +void main() +{ + P p[2]; + for (int i = 0; i < 2; ++i) { + p = g.b; + } + for (int j = 0; j < 2; ++j) { + o = p[g.a].x; + } +} + +*/ +TEST_F(FusionIllegalTest, NestedAccessChain) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %64 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 450 + OpName %4 "main" + OpName %8 "i" + OpName %20 "P" + OpMemberName %20 0 "x" + OpMemberName %20 1 "y" + OpMemberName %20 2 "z" + OpName %25 "p" + OpName %26 "P" + OpMemberName %26 0 "x" + OpMemberName %26 1 "y" + OpMemberName %26 2 "z" + OpName %28 "G" + OpMemberName %28 0 "a" + OpMemberName %28 1 "b" + OpMemberName %28 2 "c" + OpName %30 "g" + OpName %55 "j" + OpName %64 "o" + OpMemberDecorate %26 0 Offset 0 + OpMemberDecorate %26 1 Offset 4 + OpMemberDecorate %26 2 Offset 8 + OpDecorate %27 ArrayStride 16 + OpMemberDecorate %28 0 Offset 0 + OpMemberDecorate %28 1 Offset 16 + OpMemberDecorate %28 2 Offset 48 + OpDecorate %28 Block + OpDecorate %30 DescriptorSet 0 + OpDecorate %64 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 2 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeStruct %19 %19 %19 + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 2 + %23 = OpTypeArray %20 %22 + %24 = OpTypePointer Function %23 + %26 = OpTypeStruct %19 %19 %19 + %27 = OpTypeArray %26 %22 + %28 = OpTypeStruct %6 %27 %6 + %29 = OpTypePointer Uniform %28 + %30 = OpVariable %29 Uniform + %31 = OpConstant %6 1 + %32 = OpTypePointer Uniform %27 + %36 = OpTypePointer Function %20 + %39 = OpTypePointer Function %19 + %63 = OpTypePointer Output %19 + %64 = OpVariable %63 Output + %65 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %25 = OpVariable %24 Function + %55 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %72 = OpPhi %6 %9 %5 %54 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %72 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %33 = OpAccessChain %32 %30 %31 + %34 = OpLoad %27 %33 + %35 = OpCompositeExtract %26 %34 0 + %37 = OpAccessChain %36 %25 %9 + %38 = OpCompositeExtract %19 %35 0 + %40 = OpAccessChain %39 %37 %9 + OpStore %40 %38 + %41 = OpCompositeExtract %19 %35 1 + %42 = OpAccessChain %39 %37 %31 + OpStore %42 %41 + %43 = OpCompositeExtract %19 %35 2 + %44 = OpAccessChain %39 %37 %16 + OpStore %44 %43 + %45 = OpCompositeExtract %26 %34 1 + %46 = OpAccessChain %36 %25 %31 + %47 = OpCompositeExtract %19 %45 0 + %48 = OpAccessChain %39 %46 %9 + OpStore %48 %47 + %49 = OpCompositeExtract %19 %45 1 + %50 = OpAccessChain %39 %46 %31 + OpStore %50 %49 + %51 = OpCompositeExtract %19 %45 2 + %52 = OpAccessChain %39 %46 %16 + OpStore %52 %51 + OpBranch %13 + %13 = OpLabel + %54 = OpIAdd %6 %72 %31 + OpStore %8 %54 + OpBranch %10 + %12 = OpLabel + OpStore %55 %9 + OpBranch %56 + %56 = OpLabel + %73 = OpPhi %6 %9 %12 %71 %59 + OpLoopMerge %58 %59 None + OpBranch %60 + %60 = OpLabel + %62 = OpSLessThan %17 %73 %16 + OpBranchConditional %62 %57 %58 + %57 = OpLabel + %66 = OpAccessChain %65 %30 %9 + %67 = OpLoad %6 %66 + %68 = OpAccessChain %39 %25 %67 %9 + %69 = OpLoad %19 %68 + OpStore %64 %69 + OpBranch %59 + %59 = OpLabel + %71 = OpIAdd %6 %73 %31 + OpStore %55 %71 + OpBranch %56 + %58 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_legal.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_legal.cpp new file mode 100644 index 000000000..509516f80 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_legal.cpp @@ -0,0 +1,4587 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "test/opt/pass_fixture.h" + +#ifdef SPIRV_EFFCEE +#include "effcee/effcee.h" +#endif + +namespace spvtools { +namespace opt { +namespace { + +using FusionLegalTest = PassTest<::testing::Test>; + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +void Match(const std::string& checks, IRContext* context) { + // Silence unused warnings with !defined(SPIRV_EFFCE) + (void)checks; + + std::vector bin; + context->module()->ToBinary(&bin, true); + EXPECT_TRUE(Validate(bin)); +#ifdef SPIRV_EFFCEE + std::string assembly; + SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + EXPECT_TRUE( + tools.Disassemble(bin, &assembly, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + auto match_result = effcee::Match(assembly, checks); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << assembly; +#else // ! SPIRV_EFFCEE + (void)checks; +#endif +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + // No dependence, legal + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = b[i]+2; + } +} + +*/ +TEST_F(FusionLegalTest, DifferentArraysInLoops) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %42 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Only loads to the same array, legal + for (int i = 0; i < 10; i++) { + b[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionLegalTest, OnlyLoadsToSameArray) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "b" + OpName %25 "a" + OpName %35 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 2 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %35 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %52 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %52 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %52 + %28 = OpLoad %6 %27 + %30 = OpIMul %6 %28 %29 + %31 = OpAccessChain %7 %23 %52 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %52 %33 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %53 = OpPhi %6 %9 %12 %51 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %53 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpAccessChain %7 %25 %53 + %47 = OpLoad %6 %46 + %48 = OpIAdd %6 %47 %29 + %49 = OpAccessChain %7 %43 %53 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %6 %53 %33 + OpStore %35 %51 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionLegalTest, NoLoopCarriedDependences) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Parallelism inhibiting, but legal. + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + c[i-1]; + } +} + +*/ +TEST_F(FusionLegalTest, ExistingLoopCarriedDependence) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %55 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %55 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %55 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %55 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %55 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %56 = OpPhi %6 %9 %12 %54 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %56 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %56 + %46 = OpLoad %6 %45 + %48 = OpISub %6 %56 %29 + %49 = OpAccessChain %7 %42 %48 + %50 = OpLoad %6 %49 + %51 = OpIAdd %6 %46 %50 + %52 = OpAccessChain %7 %42 %56 + OpStore %52 %51 + OpBranch %38 + %38 = OpLabel + %54 = OpIAdd %6 %56 %29 + OpStore %34 %54 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[I_1:%\w+]] = OpISub {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Creates a loop-carried dependence, but negative, so legal + for (int i = 0; i < 10; i++) { + a[i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NegativeDistanceCreatedRAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %27 "b" + OpName %35 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %27 = OpVariable %22 Function + %35 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %53 %25 + %29 = OpAccessChain %7 %27 %53 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %25 + %32 = OpAccessChain %7 %23 %26 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %53 %25 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %54 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpAccessChain %7 %23 %54 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %43 %54 + OpStore %50 %49 + OpBranch %39 + %39 = OpLabel + %52 = OpIAdd %6 %54 %25 + OpStore %35 %52 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal + for (int i = 0; i < 10; i++) { + a[i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NoLoopCarriedDependencesAdjustedIndex) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %27 "b" + OpName %35 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 1 + %49 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %27 = OpVariable %22 Function + %35 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %54 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %54 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %54 %25 + %29 = OpAccessChain %7 %27 %54 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %25 + %32 = OpAccessChain %7 %23 %26 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %54 %25 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %55 = OpPhi %6 %9 %12 %53 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %55 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpIAdd %6 %55 %25 + %47 = OpAccessChain %7 %23 %46 + %48 = OpLoad %6 %47 + %50 = OpIAdd %6 %48 %49 + %51 = OpAccessChain %7 %43 %55 + OpStore %51 %50 + OpBranch %39 + %39 = OpLabel + %53 = OpIAdd %6 %55 %25 + OpStore %35 %53 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, independent locations in |a|, SIV + for (int i = 0; i < 10; i++) { + a[2*i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[2*i] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, IndependentSIV) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %29 "b" + OpName %37 "i" + OpName %45 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %24 = OpConstant %6 2 + %27 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %29 = OpVariable %22 Function + %37 = OpVariable %7 Function + %45 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %55 = OpPhi %6 %9 %5 %36 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %55 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIMul %6 %24 %55 + %28 = OpIAdd %6 %26 %27 + %31 = OpAccessChain %7 %29 %55 + %32 = OpLoad %6 %31 + %33 = OpIAdd %6 %32 %27 + %34 = OpAccessChain %7 %23 %28 + OpStore %34 %33 + OpBranch %13 + %13 = OpLabel + %36 = OpIAdd %6 %55 %27 + OpStore %8 %36 + OpBranch %10 + %12 = OpLabel + OpStore %37 %9 + OpBranch %38 + %38 = OpLabel + %56 = OpPhi %6 %9 %12 %54 %41 + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel + %44 = OpSLessThan %17 %56 %16 + OpBranchConditional %44 %39 %40 + %39 = OpLabel + %48 = OpIMul %6 %24 %56 + %49 = OpAccessChain %7 %23 %48 + %50 = OpLoad %6 %49 + %51 = OpIAdd %6 %50 %24 + %52 = OpAccessChain %7 %45 %56 + OpStore %52 %51 + OpBranch %41 + %41 = OpLabel + %54 = OpIAdd %6 %56 %27 + OpStore %37 %54 + OpBranch %38 + %40 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_2:%\w+]] = OpIMul {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[I_2_1:%\w+]] = OpIAdd {{%\w+}} [[I_2]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_2_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_2:%\w+]] = OpIMul {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, independent locations in |a|, ZIV + for (int i = 0; i < 10; i++) { + a[1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[9] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, IndependentZIV) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %33 "i" + OpName %41 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %24 = OpConstant %6 1 + %43 = OpConstant %6 9 + %46 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %33 = OpVariable %7 Function + %41 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %51 + %28 = OpLoad %6 %27 + %29 = OpIAdd %6 %28 %24 + %30 = OpAccessChain %7 %23 %24 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %32 = OpIAdd %6 %51 %24 + OpStore %8 %32 + OpBranch %10 + %12 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %17 %52 %16 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + %44 = OpAccessChain %7 %23 %43 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %45 %46 + %48 = OpAccessChain %7 %41 %52 + OpStore %48 %47 + OpBranch %37 + %37 = OpLabel + %50 = OpIAdd %6 %52 %24 + OpStore %33 %50 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NOT: OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK: OpStore +CHECK-NOT: OpPhi +CHECK-NOT: OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK: OpLoad +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[20] a; + int[10] b; + int[10] c; + // Legal, non-overlapping sections in |a| + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+10] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NonOverlappingAccesses) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %28 "b" + OpName %37 "i" + OpName %45 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 20 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %19 10 + %26 = OpTypeArray %6 %25 + %27 = OpTypePointer Function %26 + %32 = OpConstant %6 1 + %51 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %28 = OpVariable %27 Function + %37 = OpVariable %7 Function + %45 = OpVariable %27 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %56 = OpPhi %6 %9 %5 %36 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %56 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %30 = OpAccessChain %7 %28 %56 + %31 = OpLoad %6 %30 + %33 = OpIAdd %6 %31 %32 + %34 = OpAccessChain %7 %23 %56 + OpStore %34 %33 + OpBranch %13 + %13 = OpLabel + %36 = OpIAdd %6 %56 %32 + OpStore %8 %36 + OpBranch %10 + %12 = OpLabel + OpStore %37 %9 + OpBranch %38 + %38 = OpLabel + %57 = OpPhi %6 %9 %12 %55 %41 + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel + %44 = OpSLessThan %17 %57 %16 + OpBranchConditional %44 %39 %40 + %39 = OpLabel + %48 = OpIAdd %6 %57 %16 + %49 = OpAccessChain %7 %23 %48 + %50 = OpLoad %6 %49 + %52 = OpIAdd %6 %50 %51 + %53 = OpAccessChain %7 %45 %57 + OpStore %53 %52 + OpBranch %41 + %41 = OpLabel + %55 = OpIAdd %6 %57 %32 + OpStore %37 %55 + OpBranch %38 + %40 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NOT: OpPhi +CHECK: [[I_10:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_10]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, 3 adjacent loops + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 2; + } + for (int i = 0; i < 10; i++) { + b[i] = c[i] + 10; + } +} + +*/ +TEST_F(FusionLegalTest, AdjacentLoops) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + OpName %52 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + %52 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %68 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %68 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %68 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %68 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %68 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %69 = OpPhi %6 %9 %12 %51 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %69 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %69 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %42 %69 + OpStore %49 %48 + OpBranch %38 + %38 = OpLabel + %51 = OpIAdd %6 %69 %29 + OpStore %34 %51 + OpBranch %35 + %37 = OpLabel + OpStore %52 %9 + OpBranch %53 + %53 = OpLabel + %70 = OpPhi %6 %9 %37 %67 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %59 = OpSLessThan %17 %70 %16 + OpBranchConditional %59 %54 %55 + %54 = OpLabel + %62 = OpAccessChain %7 %42 %70 + %63 = OpLoad %6 %62 + %64 = OpIAdd %6 %63 %16 + %65 = OpAccessChain %7 %25 %70 + OpStore %65 %64 + OpBranch %56 + %56 = OpLabel + %67 = OpIAdd %6 %70 %29 + OpStore %52 %67 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[1], loops[2]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks, context.get()); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks_ = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks_, context.get()); + + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal inner loop fusion + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, InnerLoopFusion) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %46 "j" + OpName %54 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %46 = OpVariable %7 Function + %54 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %67 = OpPhi %6 %9 %5 %66 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %67 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %68 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %68 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %67 %68 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %67 %68 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %68 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpStore %46 %9 + OpBranch %47 + %47 = OpLabel + %69 = OpPhi %6 %9 %22 %64 %50 + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %53 = OpSLessThan %17 %69 %16 + OpBranchConditional %53 %48 %49 + %48 = OpLabel + %59 = OpAccessChain %7 %32 %67 %69 + %60 = OpLoad %6 %59 + %61 = OpIAdd %6 %60 %16 + %62 = OpAccessChain %7 %54 %67 %69 + OpStore %62 %61 + OpBranch %50 + %50 = OpLabel + %64 = OpIAdd %6 %69 %44 + OpStore %46 %64 + OpBranch %47 + %49 = OpLabel + OpBranch %13 + %13 = OpLabel + %66 = OpIAdd %6 %67 %44 + OpStore %8 %66 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 2u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 12 +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal both + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterAndInnerLoop) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %48 "i" + OpName %56 "j" + OpName %64 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %48 = OpVariable %7 Function + %56 = OpVariable %7 Function + %64 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %77 = OpPhi %6 %9 %5 %47 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %77 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %81 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %81 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %77 %81 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %77 %81 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %81 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %47 = OpIAdd %6 %77 %44 + OpStore %8 %47 + OpBranch %10 + %12 = OpLabel + OpStore %48 %9 + OpBranch %49 + %49 = OpLabel + %78 = OpPhi %6 %9 %12 %76 %52 + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %55 = OpSLessThan %17 %78 %16 + OpBranchConditional %55 %50 %51 + %50 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + %79 = OpPhi %6 %9 %50 %74 %60 + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %63 = OpSLessThan %17 %79 %16 + OpBranchConditional %63 %58 %59 + %58 = OpLabel + %69 = OpAccessChain %7 %32 %78 %79 + %70 = OpLoad %6 %69 + %71 = OpIAdd %6 %70 %16 + %72 = OpAccessChain %7 %64 %78 %79 + OpStore %72 %71 + OpBranch %60 + %60 = OpLabel + %74 = OpIAdd %6 %79 %44 + OpStore %56 %74 + OpBranch %57 + %59 = OpLabel + OpBranch %52 + %52 = OpLabel + %76 = OpIAdd %6 %78 %44 + OpStore %48 %76 + OpBranch %49 + %51 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal both, more complex + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (i % 2 == 0 && j % 2 == 0) { + c[i][j] = a[i][j] + 2; + } + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterAndInnerLoopMoreComplex) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %44 "c" + OpName %47 "a" + OpName %59 "i" + OpName %67 "j" + OpName %75 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %28 = OpConstant %6 2 + %39 = OpTypeInt 32 0 + %40 = OpConstant %39 10 + %41 = OpTypeArray %6 %40 + %42 = OpTypeArray %41 %40 + %43 = OpTypePointer Function %42 + %55 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %44 = OpVariable %43 Function + %47 = OpVariable %43 Function + %59 = OpVariable %7 Function + %67 = OpVariable %7 Function + %75 = OpVariable %43 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %88 = OpPhi %6 %9 %5 %58 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %88 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %92 = OpPhi %6 %9 %11 %56 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %92 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %29 = OpSMod %6 %88 %28 + %30 = OpIEqual %17 %29 %9 + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %32 + %31 = OpLabel + %34 = OpSMod %6 %92 %28 + %35 = OpIEqual %17 %34 %9 + OpBranch %32 + %32 = OpLabel + %36 = OpPhi %17 %30 %21 %35 %31 + OpSelectionMerge %38 None + OpBranchConditional %36 %37 %38 + %37 = OpLabel + %50 = OpAccessChain %7 %47 %88 %92 + %51 = OpLoad %6 %50 + %52 = OpIAdd %6 %51 %28 + %53 = OpAccessChain %7 %44 %88 %92 + OpStore %53 %52 + OpBranch %38 + %38 = OpLabel + OpBranch %23 + %23 = OpLabel + %56 = OpIAdd %6 %92 %55 + OpStore %19 %56 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %58 = OpIAdd %6 %88 %55 + OpStore %8 %58 + OpBranch %10 + %12 = OpLabel + OpStore %59 %9 + OpBranch %60 + %60 = OpLabel + %89 = OpPhi %6 %9 %12 %87 %63 + OpLoopMerge %62 %63 None + OpBranch %64 + %64 = OpLabel + %66 = OpSLessThan %17 %89 %16 + OpBranchConditional %66 %61 %62 + %61 = OpLabel + OpStore %67 %9 + OpBranch %68 + %68 = OpLabel + %90 = OpPhi %6 %9 %61 %85 %71 + OpLoopMerge %70 %71 None + OpBranch %72 + %72 = OpLabel + %74 = OpSLessThan %17 %90 %16 + OpBranchConditional %74 %69 %70 + %69 = OpLabel + %80 = OpAccessChain %7 %44 %89 %90 + %81 = OpLoad %6 %80 + %82 = OpIAdd %6 %81 %16 + %83 = OpAccessChain %7 %75 %89 %90 + OpStore %83 %82 + OpBranch %71 + %71 = OpLabel + %85 = OpIAdd %6 %90 %55 + OpStore %67 %85 + OpBranch %68 + %70 = OpLabel + OpBranch %63 + %63 = OpLabel + %87 = OpIAdd %6 %89 %55 + OpStore %59 %87 + OpBranch %60 + %62 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: OpPhi +CHECK-NEXT: OpSelectionMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: OpPhi +CHECK-NEXT: OpSelectionMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Outer would have been illegal to fuse, but since written + // like this, inner loop fusion is legal. + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + for (int j = 0; j < 10; j++) { + b[i][j] = c[i+1][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, InnerWithExistingDependenceOnOuter) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %46 "j" + OpName %54 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %46 = OpVariable %7 Function + %54 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %68 = OpPhi %6 %9 %5 %67 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %68 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %69 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %69 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %68 %69 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %68 %69 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %69 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpStore %46 %9 + OpBranch %47 + %47 = OpLabel + %70 = OpPhi %6 %9 %22 %65 %50 + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %53 = OpSLessThan %17 %70 %16 + OpBranchConditional %53 %48 %49 + %48 = OpLabel + %58 = OpIAdd %6 %68 %44 + %60 = OpAccessChain %7 %32 %58 %70 + %61 = OpLoad %6 %60 + %62 = OpIAdd %6 %61 %16 + %63 = OpAccessChain %7 %54 %68 %70 + OpStore %63 %62 + OpBranch %50 + %50 = OpLabel + %65 = OpIAdd %6 %70 %44 + OpStore %46 %65 + OpBranch %47 + %49 = OpLabel + OpBranch %13 + %13 = OpLabel + %67 = OpIAdd %6 %68 %44 + OpStore %8 %67 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI_0]] {{%\w+}} +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // One dimensional arrays. Legal, outer dist 0, inner independent. + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i] = a[j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[j] = c[i] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterAndInnerLoopOneDimArrays) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %31 "c" + OpName %33 "a" + OpName %45 "i" + OpName %53 "j" + OpName %61 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypePointer Function %29 + %37 = OpConstant %6 2 + %41 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %31 = OpVariable %30 Function + %33 = OpVariable %30 Function + %45 = OpVariable %7 Function + %53 = OpVariable %7 Function + %61 = OpVariable %30 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %72 = OpPhi %6 %9 %5 %44 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %72 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %76 = OpPhi %6 %9 %11 %42 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %76 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %35 = OpAccessChain %7 %33 %76 + %36 = OpLoad %6 %35 + %38 = OpIAdd %6 %36 %37 + %39 = OpAccessChain %7 %31 %72 + OpStore %39 %38 + OpBranch %23 + %23 = OpLabel + %42 = OpIAdd %6 %76 %41 + OpStore %19 %42 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %44 = OpIAdd %6 %72 %41 + OpStore %8 %44 + OpBranch %10 + %12 = OpLabel + OpStore %45 %9 + OpBranch %46 + %46 = OpLabel + %73 = OpPhi %6 %9 %12 %71 %49 + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %52 = OpSLessThan %17 %73 %16 + OpBranchConditional %52 %47 %48 + %47 = OpLabel + OpStore %53 %9 + OpBranch %54 + %54 = OpLabel + %74 = OpPhi %6 %9 %47 %69 %57 + OpLoopMerge %56 %57 None + OpBranch %58 + %58 = OpLabel + %60 = OpSLessThan %17 %74 %16 + OpBranchConditional %60 %55 %56 + %55 = OpLabel + %64 = OpAccessChain %7 %31 %73 + %65 = OpLoad %6 %64 + %66 = OpIAdd %6 %65 %16 + %67 = OpAccessChain %7 %61 %74 + OpStore %67 %66 + OpBranch %57 + %57 = OpLabel + %69 = OpIAdd %6 %74 %41 + OpStore %53 %69 + OpBranch %54 + %56 = OpLabel + OpBranch %49 + %49 = OpLabel + %71 = OpIAdd %6 %73 %41 + OpStore %45 %71 + OpBranch %46 + %48 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, creates a loop-carried dependence, but has negative distance + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 1; + } + for (int i = 0; i < 10; i++) { + a[i] = c[i] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NegativeDistanceCreatedWAR) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "c" + OpName %25 "a" + OpName %35 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %27 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %35 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %52 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %52 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %28 = OpIAdd %6 %52 %27 + %29 = OpAccessChain %7 %25 %28 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %27 + %32 = OpAccessChain %7 %23 %52 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %52 %27 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %53 = OpPhi %6 %9 %12 %51 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %53 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %45 = OpAccessChain %7 %23 %53 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %25 %53 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %6 %53 %27 + OpStore %35 %51 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, creates a loop-carried dependence, but has negative distance + for (int i = 0; i < 10; i++) { + a[i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i] = c[i+1] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NegativeDistanceCreatedWAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %27 "b" + OpName %35 "i" + OpName %44 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 1 + %49 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %27 = OpVariable %22 Function + %35 = OpVariable %7 Function + %44 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %54 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %54 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %54 %25 + %29 = OpAccessChain %7 %27 %54 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %25 + %32 = OpAccessChain %7 %23 %26 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %54 %25 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %55 = OpPhi %6 %9 %12 %53 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %55 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpIAdd %6 %55 %25 + %47 = OpAccessChain %7 %44 %46 + %48 = OpLoad %6 %47 + %50 = OpIAdd %6 %48 %49 + %51 = OpAccessChain %7 %23 %55 + OpStore %51 %50 + OpBranch %39 + %39 = OpLabel + %53 = OpIAdd %6 %55 %25 + OpStore %35 %53 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpStore +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, no loop-carried dependence + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i] = c[i+1] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NoLoopCarriedDependencesWAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %53 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %53 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %53 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %54 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpIAdd %6 %54 %29 + %46 = OpAccessChain %7 %43 %45 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %23 %54 + OpStore %50 %49 + OpBranch %38 + %38 = OpLabel + %52 = OpIAdd %6 %54 %29 + OpStore %34 %52 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal outer. Continue and break are fine if nested in inner loops + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (j % 2 == 0) { + c[i][j] = a[i][j] + 2; + } else { + continue; + } + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (j % 2 == 0) { + b[i][j] = c[i][j] + 10; + } else { + break; + } + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterloopWithBreakContinueInInner) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %38 "c" + OpName %41 "a" + OpName %55 "i" + OpName %63 "j" + OpName %76 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %28 = OpConstant %6 2 + %33 = OpTypeInt 32 0 + %34 = OpConstant %33 10 + %35 = OpTypeArray %6 %34 + %36 = OpTypeArray %35 %34 + %37 = OpTypePointer Function %36 + %51 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %38 = OpVariable %37 Function + %41 = OpVariable %37 Function + %55 = OpVariable %7 Function + %63 = OpVariable %7 Function + %76 = OpVariable %37 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %91 = OpPhi %6 %9 %5 %54 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %91 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %96 = OpPhi %6 %9 %11 %52 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %96 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %29 = OpSMod %6 %96 %28 + %30 = OpIEqual %17 %29 %9 + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %48 + %31 = OpLabel + %44 = OpAccessChain %7 %41 %91 %96 + %45 = OpLoad %6 %44 + %46 = OpIAdd %6 %45 %28 + %47 = OpAccessChain %7 %38 %91 %96 + OpStore %47 %46 + OpBranch %32 + %48 = OpLabel + OpBranch %23 + %32 = OpLabel + OpBranch %23 + %23 = OpLabel + %52 = OpIAdd %6 %96 %51 + OpStore %19 %52 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %54 = OpIAdd %6 %91 %51 + OpStore %8 %54 + OpBranch %10 + %12 = OpLabel + OpStore %55 %9 + OpBranch %56 + %56 = OpLabel + %92 = OpPhi %6 %9 %12 %90 %59 + OpLoopMerge %58 %59 None + OpBranch %60 + %60 = OpLabel + %62 = OpSLessThan %17 %92 %16 + OpBranchConditional %62 %57 %58 + %57 = OpLabel + OpStore %63 %9 + OpBranch %64 + %64 = OpLabel + %93 = OpPhi %6 %9 %57 %88 %67 + OpLoopMerge %66 %67 None + OpBranch %68 + %68 = OpLabel + %70 = OpSLessThan %17 %93 %16 + OpBranchConditional %70 %65 %66 + %65 = OpLabel + %72 = OpSMod %6 %93 %28 + %73 = OpIEqual %17 %72 %9 + OpSelectionMerge %75 None + OpBranchConditional %73 %74 %85 + %74 = OpLabel + %81 = OpAccessChain %7 %38 %92 %93 + %82 = OpLoad %6 %81 + %83 = OpIAdd %6 %82 %16 + %84 = OpAccessChain %7 %76 %92 %93 + OpStore %84 %83 + OpBranch %75 + %85 = OpLabel + OpBranch %66 + %75 = OpLabel + OpBranch %67 + %67 = OpLabel + %88 = OpIAdd %6 %93 %51 + OpStore %63 %88 + OpBranch %64 + %66 = OpLabel + OpBranch %59 + %59 = OpLabel + %90 = OpIAdd %6 %92 %51 + OpStore %55 %90 + OpBranch %56 + %58 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[2]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[1], loops[2]); + EXPECT_FALSE(fusion.AreCompatible()); + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// j loop preheader removed manually +#version 440 core +void main() { + int[10] a; + int[10] b; + int i = 0; + int j = 0; + // No loop-carried dependences, legal + for (; i < 10; i++) { + a[i] = a[i]*2; + } + for (; j < 10; j++) { + b[j] = a[j]+2; + } +} + +*/ +TEST_F(FusionLegalTest, DifferentArraysInLoopsNoPreheader) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %10 "j" + OpName %24 "a" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %29 = OpConstant %6 2 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %42 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %51 = OpPhi %6 %9 %5 %34 %14 + OpLoopMerge %35 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %51 %17 + OpBranchConditional %19 %12 %35 + %12 = OpLabel + %27 = OpAccessChain %7 %24 %51 + %28 = OpLoad %6 %27 + %30 = OpIMul %6 %28 %29 + %31 = OpAccessChain %7 %24 %51 + OpStore %31 %30 + OpBranch %14 + %14 = OpLabel + %34 = OpIAdd %6 %51 %33 + OpStore %8 %34 + OpBranch %11 + %35 = OpLabel + %52 = OpPhi %6 %9 %15 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %18 %52 %17 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %24 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %29 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %33 + OpStore %10 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); + } + + ld.CreatePreHeaderBlocksIfMissing(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// j & k loop preheaders removed manually +#version 440 core +void main() { + int[10] a; + int[10] b; + int i = 0; + int j = 0; + int k = 0; + // No loop-carried dependences, legal + for (; i < 10; i++) { + a[i] = a[i]*2; + } + for (; j < 10; j++) { + b[j] = a[j]+2; + } + for (; k < 10; k++) { + a[k] = a[k]*2; + } +} + +*/ +TEST_F(FusionLegalTest, AdjacentLoopsNoPreheaders) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %10 "j" + OpName %11 "k" + OpName %25 "a" + OpName %43 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %30 = OpConstant %6 2 + %34 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %43 = OpVariable %24 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %67 = OpPhi %6 %9 %5 %35 %15 + OpLoopMerge %36 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %67 %18 + OpBranchConditional %20 %13 %36 + %13 = OpLabel + %28 = OpAccessChain %7 %25 %67 + %29 = OpLoad %6 %28 + %31 = OpIMul %6 %29 %30 + %32 = OpAccessChain %7 %25 %67 + OpStore %32 %31 + OpBranch %15 + %15 = OpLabel + %35 = OpIAdd %6 %67 %34 + OpStore %8 %35 + OpBranch %12 + %36 = OpLabel + %68 = OpPhi %6 %9 %16 %51 %39 + OpLoopMerge %52 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %19 %68 %18 + OpBranchConditional %42 %37 %52 + %37 = OpLabel + %46 = OpAccessChain %7 %25 %68 + %47 = OpLoad %6 %46 + %48 = OpIAdd %6 %47 %30 + %49 = OpAccessChain %7 %43 %68 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %6 %68 %34 + OpStore %10 %51 + OpBranch %36 + %52 = OpLabel + %70 = OpPhi %6 %9 %40 %66 %55 + OpLoopMerge %54 %55 None + OpBranch %56 + %56 = OpLabel + %58 = OpSLessThan %19 %70 %18 + OpBranchConditional %58 %53 %54 + %53 = OpLabel + %61 = OpAccessChain %7 %25 %70 + %62 = OpLoad %6 %61 + %63 = OpIMul %6 %62 %30 + %64 = OpAccessChain %7 %25 %70 + OpStore %64 %63 + OpBranch %55 + %55 = OpLabel + %66 = OpIAdd %6 %70 %34 + OpStore %11 %66 + OpBranch %52 + %54 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); + } + + ld.CreatePreHeaderBlocksIfMissing(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks, context.get()); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + int sum_1 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_1 += b[j]; + } + + int total = sum_0 + sum_1; +} + +*/ +TEST_F(FusionLegalTest, IndependentReductions) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "sum_1" + OpName %11 "i" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + OpName %50 "total" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + %50 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %57 = OpPhi %6 %9 %5 %30 %15 + %54 = OpPhi %6 %9 %5 %33 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %54 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %57 %28 + OpStore %8 %30 + OpBranch %15 + %15 = OpLabel + %33 = OpIAdd %6 %54 %32 + OpStore %11 %33 + OpBranch %12 + %14 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %58 = OpPhi %6 %9 %14 %47 %38 + %55 = OpPhi %6 %9 %14 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %19 %55 %18 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %42 %55 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %58 %45 + OpStore %10 %47 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %55 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + %53 = OpIAdd %6 %57 %58 + OpStore %50 %53 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[SUM_0:%\w+]] = OpPhi +CHECK-NEXT: [[SUM_1:%\w+]] = OpPhi +CHECK-NEXT: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] +CHECK-NEXT: [[ADD_RES_1:%\w+]] = OpIAdd {{%\w+}} [[SUM_1]] [[LOAD_RES_1]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + int sum_1 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_1 += b[j]; + } + + int total = sum_0 + sum_1; +} + +*/ +TEST_F(FusionLegalTest, IndependentReductionsOneLCSSA) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "sum_1" + OpName %11 "i" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + OpName %50 "total" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + %50 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %57 = OpPhi %6 %9 %5 %30 %15 + %54 = OpPhi %6 %9 %5 %33 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %54 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %57 %28 + OpStore %8 %30 + OpBranch %15 + %15 = OpLabel + %33 = OpIAdd %6 %54 %32 + OpStore %11 %33 + OpBranch %12 + %14 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %58 = OpPhi %6 %9 %14 %47 %38 + %55 = OpPhi %6 %9 %14 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %19 %55 %18 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %42 %55 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %58 %45 + OpStore %10 %47 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %55 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + %53 = OpIAdd %6 %57 %58 + OpStore %50 %53 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopUtils utils_0(context.get(), loops[0]); + utils_0.MakeLoopClosedSSA(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[SUM_0:%\w+]] = OpPhi +CHECK-NEXT: [[SUM_1:%\w+]] = OpPhi +CHECK-NEXT: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] +CHECK-NEXT: [[ADD_RES_1:%\w+]] = OpIAdd {{%\w+}} [[SUM_1]] [[LOAD_RES_1]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + int sum_1 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_1 += b[j]; + } + + int total = sum_0 + sum_1; +} + +*/ +TEST_F(FusionLegalTest, IndependentReductionsBothLCSSA) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "sum_1" + OpName %11 "i" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + OpName %50 "total" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + %50 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %57 = OpPhi %6 %9 %5 %30 %15 + %54 = OpPhi %6 %9 %5 %33 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %54 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %57 %28 + OpStore %8 %30 + OpBranch %15 + %15 = OpLabel + %33 = OpIAdd %6 %54 %32 + OpStore %11 %33 + OpBranch %12 + %14 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %58 = OpPhi %6 %9 %14 %47 %38 + %55 = OpPhi %6 %9 %14 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %19 %55 %18 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %42 %55 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %58 %45 + OpStore %10 %47 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %55 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + %53 = OpIAdd %6 %57 %58 + OpStore %50 %53 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopUtils utils_0(context.get(), loops[0]); + utils_0.MakeLoopClosedSSA(); + LoopUtils utils_1(context.get(), loops[1]); + utils_1.MakeLoopClosedSSA(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[SUM_0:%\w+]] = OpPhi +CHECK-NEXT: [[SUM_1:%\w+]] = OpPhi +CHECK-NEXT: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] +CHECK-NEXT: [[ADD_RES_1:%\w+]] = OpIAdd {{%\w+}} [[SUM_1]] [[LOAD_RES_1]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + a[j] = b[j]; + } +} + +*/ +TEST_F(FusionLegalTest, LoadStoreReductionAndNonLoopCarriedDependence) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "i" + OpName %24 "a" + OpName %33 "j" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %33 = OpVariable %7 Function + %42 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %51 = OpPhi %6 %9 %5 %29 %14 + %49 = OpPhi %6 %9 %5 %32 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %49 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %49 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %51 %27 + OpStore %8 %29 + OpBranch %14 + %14 = OpLabel + %32 = OpIAdd %6 %49 %31 + OpStore %10 %32 + OpBranch %11 + %13 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %50 = OpPhi %6 %9 %13 %48 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %18 %50 %17 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + %44 = OpAccessChain %7 %42 %50 + %45 = OpLoad %6 %44 + %46 = OpAccessChain %7 %24 %50 + OpStore %46 %45 + OpBranch %37 + %37 = OpLabel + %48 = OpIAdd %6 %50 %31 + OpStore %33 %48 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + // TODO: Loop descriptor doesn't return induction variables but all OpPhi + // in the header and LoopDependenceAnalysis falls over. + // EXPECT_TRUE(fusion.IsLegal()); + + // fusion.Fuse(); + } + + { + // LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + // EXPECT_EQ(ld.NumLoops(), 1u); + + // std::string checks = R"( + // CHECK: [[SUM_0:%\w+]] = OpPhi + // CHECK-NEXT: [[PHI:%\w+]] = OpPhi + // CHECK-NEXT: OpLoopMerge + // CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] + // CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] + // CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] + // CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] + // CHECK-NOT: OpPhi + // CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] + // CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] + // CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] + // CHECK-NEXT: OpStore [[STORE_1]] [[LOAD_RES_1]] + // )"; + + // Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +int x; +void main() { + int[10] a; + int[10] b; + + // Legal. + for (int i = 0; i < 10; i++) { + x += a[i]; + } + for (int j = 0; j < 10; j++) { + b[j] = b[j]+1; + } +} + +*/ +TEST_F(FusionLegalTest, ReductionAndNonLoopCarriedDependence) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %20 "x" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypePointer Private %6 + %20 = OpVariable %19 Private + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %51 + %28 = OpLoad %6 %27 + %29 = OpLoad %6 %20 + %30 = OpIAdd %6 %29 %28 + OpStore %20 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %42 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %32 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: OpName [[X:%\w+]] "x" +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[X_LOAD:%\w+]] = OpLoad {{%\w+}} [[X]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[X_LOAD]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore [[X]] [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: {{%\w+}} = OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +struct TestStruct { + int[10] a; + int b; +}; + +void main() { + TestStruct test_0; + TestStruct test_1; + TestStruct test_2; + + test_1.b = 2; + + for (int i = 0; i < 10; i++) { + test_0.a[i] = i; + } + for (int j = 0; j < 10; j++) { + test_2 = test_1; + } +} + +*/ +TEST_F(FusionLegalTest, ArrayInStruct) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %10 "TestStruct" + OpMemberName %10 0 "a" + OpMemberName %10 1 "b" + OpName %12 "test_1" + OpName %17 "i" + OpName %28 "test_0" + OpName %34 "j" + OpName %42 "test_2" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeInt 32 0 + %8 = OpConstant %7 10 + %9 = OpTypeArray %6 %8 + %10 = OpTypeStruct %9 %6 + %11 = OpTypePointer Function %10 + %13 = OpConstant %6 1 + %14 = OpConstant %6 2 + %15 = OpTypePointer Function %6 + %18 = OpConstant %6 0 + %25 = OpConstant %6 10 + %26 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + %12 = OpVariable %11 Function + %17 = OpVariable %15 Function + %28 = OpVariable %11 Function + %34 = OpVariable %15 Function + %42 = OpVariable %11 Function + %16 = OpAccessChain %15 %12 %13 + OpStore %16 %14 + OpStore %17 %18 + OpBranch %19 + %19 = OpLabel + %46 = OpPhi %6 %18 %5 %33 %22 + OpLoopMerge %21 %22 None + OpBranch %23 + %23 = OpLabel + %27 = OpSLessThan %26 %46 %25 + OpBranchConditional %27 %20 %21 + %20 = OpLabel + %31 = OpAccessChain %15 %28 %18 %46 + OpStore %31 %46 + OpBranch %22 + %22 = OpLabel + %33 = OpIAdd %6 %46 %13 + OpStore %17 %33 + OpBranch %19 + %21 = OpLabel + OpStore %34 %18 + OpBranch %35 + %35 = OpLabel + %47 = OpPhi %6 %18 %21 %45 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %26 %47 %25 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %43 = OpLoad %10 %12 + OpStore %42 %43 + OpBranch %38 + %38 = OpLabel + %45 = OpIAdd %6 %47 %13 + OpStore %34 %45 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + // clang-format off + std::string checks = R"( +CHECK: OpName [[TEST_1:%\w+]] "test_1" +CHECK: OpName [[TEST_0:%\w+]] "test_0" +CHECK: OpName [[TEST_2:%\w+]] "test_2" +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[TEST_0_STORE:%\w+]] = OpAccessChain {{%\w+}} [[TEST_0]] {{%\w+}} {{%\w+}} +CHECK-NEXT: OpStore [[TEST_0_STORE]] [[PHI]] +CHECK-NOT: OpPhi +CHECK: [[TEST_1_LOAD:%\w+]] = OpLoad {{%\w+}} [[TEST_1]] +CHECK: OpStore [[TEST_2]] [[TEST_1_LOAD]] + )"; + // clang-format on + + Match(checks, context.get()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_pass.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_pass.cpp new file mode 100644 index 000000000..857ada939 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/fusion_pass.cpp @@ -0,0 +1,724 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" + +#ifdef SPIRV_EFFCEE +#include "effcee/effcee.h" +#endif + +namespace spvtools { +namespace opt { +namespace { + +using FusionPassTest = PassTest<::testing::Test>; + +#ifdef SPIRV_EFFCEE + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionPassTest, SimpleFusion) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 2; + } + for (int i = 0; i < 10; i++) { + b[i] = c[i] + 10; + } +} + +*/ +TEST_F(FusionPassTest, ThreeLoopsFused) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + OpName %52 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + %52 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %68 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %68 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %68 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %68 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %68 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %69 = OpPhi %6 %9 %12 %51 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %69 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %69 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %42 %69 + OpStore %49 %48 + OpBranch %38 + %38 = OpLabel + %51 = OpIAdd %6 %69 %29 + OpStore %34 %51 + OpBranch %35 + %37 = OpLabel + OpStore %52 %9 + OpBranch %53 + %53 = OpLabel + %70 = OpPhi %6 %9 %37 %67 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %59 = OpSLessThan %17 %70 %16 + OpBranchConditional %59 %54 %55 + %54 = OpLabel + %62 = OpAccessChain %7 %42 %70 + %63 = OpLoad %6 %62 + %64 = OpIAdd %6 %63 %16 + %65 = OpAccessChain %7 %25 %70 + OpStore %65 %64 + OpBranch %56 + %56 = OpLabel + %67 = OpIAdd %6 %70 %29 + OpStore %52 %67 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal both + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionPassTest, NestedLoopsFused) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %48 "i" + OpName %56 "j" + OpName %64 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %48 = OpVariable %7 Function + %56 = OpVariable %7 Function + %64 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %77 = OpPhi %6 %9 %5 %47 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %77 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %81 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %81 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %77 %81 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %77 %81 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %81 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %47 = OpIAdd %6 %77 %44 + OpStore %8 %47 + OpBranch %10 + %12 = OpLabel + OpStore %48 %9 + OpBranch %49 + %49 = OpLabel + %78 = OpPhi %6 %9 %12 %76 %52 + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %55 = OpSLessThan %17 %78 %16 + OpBranchConditional %55 %50 %51 + %50 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + %79 = OpPhi %6 %9 %50 %74 %60 + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %63 = OpSLessThan %17 %79 %16 + OpBranchConditional %63 %58 %59 + %58 = OpLabel + %69 = OpAccessChain %7 %32 %78 %79 + %70 = OpLoad %6 %69 + %71 = OpIAdd %6 %70 %16 + %72 = OpAccessChain %7 %64 %78 %79 + OpStore %72 %71 + OpBranch %60 + %60 = OpLabel + %74 = OpIAdd %6 %79 %44 + OpStore %56 %74 + OpBranch %57 + %59 = OpLabel + OpBranch %52 + %52 = OpLabel + %76 = OpIAdd %6 %78 %44 + OpStore %48 %76 + OpBranch %49 + %51 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + // Can't fuse, different step + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j=j+2) {} +} + +*/ +TEST_F(FusionPassTest, Incompatible) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %31 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %34 = OpPhi %6 %9 %12 %32 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %34 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %32 = OpIAdd %6 %34 %31 + OpStore %22 %32 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, loop-independent dependence will become a + // backward loop-carried antidependence + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 2; + } +} + +*/ +TEST_F(FusionPassTest, Illegal) { + std::string text = R"( +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %53 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %53 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %53 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %54 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpIAdd %6 %54 %29 + %46 = OpAccessChain %7 %23 %45 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %42 %54 + OpStore %50 %49 + OpBranch %38 + %38 = OpLabel + %52 = OpIAdd %6 %54 %29 + OpStore %34 %52 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionPassTest, TooManyRegisters) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 5); +} + +#endif + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_all_loop_types.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_all_loop_types.cpp index bb42fdd0c..27e0a0d91 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_all_loop_types.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_all_loop_types.cpp @@ -14,16 +14,15 @@ #include -#include - -#include "../pass_fixture.h" -#include "opt/licm_pass.h" +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -278,7 +277,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before_hoist, after_hoist, true); + SinglePassRunAndCheck(before_hoist, after_hoist, true); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_double_nested_loops.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_double_nested_loops.cpp index 46bb562b2..ea1949658 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_double_nested_loops.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_double_nested_loops.cpp @@ -14,16 +14,15 @@ #include -#include - -#include "../pass_fixture.h" -#include "opt/licm_pass.h" +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -155,7 +154,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before_hoist, after_hoist, true); + SinglePassRunAndCheck(before_hoist, after_hoist, true); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_from_independent_loops.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_from_independent_loops.cpp index 1c7bebaf1..abc79e37c 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_from_independent_loops.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_from_independent_loops.cpp @@ -14,16 +14,15 @@ #include -#include - -#include "../pass_fixture.h" -#include "opt/licm_pass.h" +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -194,7 +193,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before_hoist, after_hoist, true); + SinglePassRunAndCheck(before_hoist, after_hoist, true); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_simple_case.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_simple_case.cpp index ce29c6302..e973d9d29 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_simple_case.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_simple_case.cpp @@ -14,16 +14,15 @@ #include -#include - -#include "../pass_fixture.h" -#include "opt/licm_pass.h" +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -119,7 +118,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before_hoist, after_hoist, true); + SinglePassRunAndCheck(before_hoist, after_hoist, true); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_single_nested_loops.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_single_nested_loops.cpp index 563c48960..7fa1fb0a0 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_single_nested_loops.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_single_nested_loops.cpp @@ -14,14 +14,14 @@ #include -#include - -#include "../pass_fixture.h" -#include "opt/licm_pass.h" +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; using PassClassTest = PassTest<::testing::Test>; @@ -155,7 +155,9 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before_hoist, after_hoist, true); + SinglePassRunAndCheck(before_hoist, after_hoist, true); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_without_preheader.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_without_preheader.cpp index fe60fc82d..9e8d02fac 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_without_preheader.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/hoist_without_preheader.cpp @@ -14,16 +14,15 @@ #include -#include - -#include "../pass_fixture.h" -#include "opt/licm_pass.h" +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -116,8 +115,10 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } #endif } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/lcssa.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/lcssa.cpp index 95d5a2a57..220772652 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/lcssa.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/lcssa.cpp @@ -12,28 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt//assembly_builder.h" +#include "test/opt/function_utils.h" + #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "../assembly_builder.h" -#include "../function_utils.h" - -#include "opt/build_module.h" -#include "opt/loop_descriptor.h" -#include "opt/loop_utils.h" -#include "opt/pass.h" - +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - #ifdef SPIRV_EFFCEE bool Validate(const std::vector& bin) { @@ -48,7 +46,7 @@ bool Validate(const std::vector& bin) { return error == 0; } -void Match(const std::string& original, ir::IRContext* context, +void Match(const std::string& original, IRContext* context, bool do_validation = true) { std::vector bin; context->module()->ToBinary(&bin, true); @@ -136,18 +134,18 @@ TEST_F(LCSSATest, SimpleLCSSA) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; - ir::Loop* loop = ld[17]; + Loop* loop = ld[17]; EXPECT_FALSE(loop->IsLCSSA()); - opt::LoopUtils Util(context.get(), loop); + LoopUtils Util(context.get(), loop); Util.MakeLoopClosedSSA(); EXPECT_TRUE(loop->IsLCSSA()); Match(text, context.get()); @@ -222,18 +220,18 @@ TEST_F(LCSSATest, PhiReuseLCSSA) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; - ir::Loop* loop = ld[17]; + Loop* loop = ld[17]; EXPECT_FALSE(loop->IsLCSSA()); - opt::LoopUtils Util(context.get(), loop); + LoopUtils Util(context.get(), loop); Util.MakeLoopClosedSSA(); EXPECT_TRUE(loop->IsLCSSA()); Match(text, context.get()); @@ -321,18 +319,18 @@ TEST_F(LCSSATest, DualLoopLCSSA) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; - ir::Loop* loop = ld[16]; + Loop* loop = ld[16]; EXPECT_FALSE(loop->IsLCSSA()); - opt::LoopUtils Util(context.get(), loop); + LoopUtils Util(context.get(), loop); Util.MakeLoopClosedSSA(); EXPECT_TRUE(loop->IsLCSSA()); Match(text, context.get()); @@ -414,18 +412,18 @@ TEST_F(LCSSATest, PhiUserLCSSA) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; - ir::Loop* loop = ld[19]; + Loop* loop = ld[19]; EXPECT_FALSE(loop->IsLCSSA()); - opt::LoopUtils Util(context.get(), loop); + LoopUtils Util(context.get(), loop); Util.MakeLoopClosedSSA(); EXPECT_TRUE(loop->IsLCSSA()); Match(text, context.get()); @@ -509,18 +507,18 @@ TEST_F(LCSSATest, LCSSAWithBreak) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; - ir::Loop* loop = ld[19]; + Loop* loop = ld[19]; EXPECT_FALSE(loop->IsLCSSA()); - opt::LoopUtils Util(context.get(), loop); + LoopUtils Util(context.get(), loop); Util.MakeLoopClosedSSA(); EXPECT_TRUE(loop->IsLCSSA()); Match(text, context.get()); @@ -592,18 +590,18 @@ TEST_F(LCSSATest, LCSSAUseInNonEligiblePhi) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; - ir::Loop* loop = ld[12]; + Loop* loop = ld[12]; EXPECT_FALSE(loop->IsLCSSA()); - opt::LoopUtils Util(context.get(), loop); + LoopUtils Util(context.get(), loop); Util.MakeLoopClosedSSA(); EXPECT_TRUE(loop->IsLCSSA()); Match(text, context.get()); @@ -612,3 +610,5 @@ TEST_F(LCSSATest, LCSSAUseInNonEligiblePhi) { #endif // SPIRV_EFFCEE } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_descriptions.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_descriptions.cpp index f53ad0508..91dbdc6b5 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_descriptions.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_descriptions.cpp @@ -12,24 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/loop_descriptor.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -90,18 +89,18 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f); + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); EXPECT_EQ(ld.NumLoops(), 1u); - ir::Loop& loop = ld.GetLoopByIndex(0); + Loop& loop = ld.GetLoopByIndex(0); EXPECT_EQ(loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 18)); EXPECT_EQ(loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 20)); EXPECT_EQ(loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 19)); @@ -187,18 +186,18 @@ TEST_F(PassClassTest, LoopWithNoPreHeader) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f); + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); EXPECT_EQ(ld.NumLoops(), 2u); - ir::Loop* loop = ld[27]; + Loop* loop = ld[27]; EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr); EXPECT_NE(loop->GetOrCreatePreHeaderBlock(), nullptr); } @@ -285,16 +284,101 @@ TEST_F(PassClassTest, NoLoop) { OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - ir::LoopDescriptor ld{f}; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor ld{context.get(), f}; EXPECT_EQ(ld.NumLoops(), 0u); } +/* +Generated from following GLSL with latch block artificially inserted to be +seperate from continue. +#version 430 +void main(void) { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, LoopLatchNotContinue) { + const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "x" + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpTypeFloat 32 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 10 + %15 = OpTypeArray %12 %14 + %16 = OpTypePointer Function %15 + %17 = OpTypePointer Function %12 + %18 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %19 = OpLabel + %3 = OpVariable %8 Function + %4 = OpVariable %16 Function + OpStore %3 %9 + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %7 %9 %19 %22 %30 + OpLoopMerge %24 %23 None + OpBranch %25 + %25 = OpLabel + %26 = OpSLessThan %11 %21 %10 + OpBranchConditional %26 %27 %24 + %27 = OpLabel + %28 = OpConvertSToF %12 %21 + %29 = OpAccessChain %17 %4 %21 + OpStore %29 %28 + OpBranch %23 + %23 = OpLabel + %22 = OpIAdd %7 %21 %18 + OpStore %3 %22 + OpBranch %30 + %30 = OpLabel + OpBranch %20 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + EXPECT_EQ(ld.NumLoops(), 1u); + + Loop& loop = ld.GetLoopByIndex(0u); + + EXPECT_NE(loop.GetLatchBlock(), loop.GetContinueBlock()); + + EXPECT_EQ(loop.GetContinueBlock()->id(), 23u); + EXPECT_EQ(loop.GetLatchBlock()->id(), 30u); +} + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_fission.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_fission.cpp new file mode 100644 index 000000000..e513f4253 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/loop_fission.cpp @@ -0,0 +1,3491 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_fission.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using FissionClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + +#version 430 + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i]; + } +} + +Result should be equivalent to: + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + } + + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + } +} +*/ +TEST_F(FissionClassTest, SimpleFission) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpAccessChain %18 %4 %22 +%33 = OpLoad %13 %32 +%34 = OpAccessChain %18 %5 %22 +OpStore %34 %33 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %35 +%35 = OpLabel +%36 = OpPhi %8 %10 %20 %47 %46 +OpLoopMerge %48 %46 None +OpBranch %37 +%37 = OpLabel +%38 = OpSLessThan %12 %36 %11 +OpBranchConditional %38 %39 %48 +%39 = OpLabel +%40 = OpAccessChain %18 %5 %36 +%41 = OpLoad %13 %40 +%42 = OpAccessChain %18 %4 %36 +OpStore %42 %41 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %8 %36 %19 +OpBranch %35 +%48 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %48 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%32 = OpAccessChain %18 %4 %22 +%33 = OpLoad %13 %32 +%34 = OpAccessChain %18 %5 %22 +OpStore %34 %33 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + // Check that the loop will NOT be split when provided with a pass-through + // register pressure functor which just returns false. + SinglePassRunAndCheck( + source, source, true, + [](const RegisterLiveness::RegionRegisterLiveness&) { return false; }); +} + +/* +Generated from the following GLSL + +#version 430 + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i+1]; + } +} + +This loop should not be split, as the i+1 dependence would be broken by +splitting the loop. +*/ + +TEST_F(FissionClassTest, FissionInterdependency) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIAdd %8 %22 %19 +%33 = OpAccessChain %18 %4 %32 +%34 = OpLoad %13 %33 +%35 = OpAccessChain %18 %5 %22 +OpStore %35 %34 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +Generated from the following GLSL + +#version 430 + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i+1] = A[i]; + } +} + + +This should not be split as the load B[i] is dependent on the store B[i+1] +*/ +TEST_F(FissionClassTest, FissionInterdependency2) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIAdd %8 %22 %19 +%33 = OpAccessChain %18 %4 %22 +%34 = OpLoad %13 %33 +%35 = OpAccessChain %18 %5 %32 +OpStore %35 %34 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + float C[10] + float D[10] + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i]; + C[i] = D[i]; + D[i] = C[i]; + } +} + +This should be split into the equivalent of: + + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i]; + } + for (int i = 0; i < 10; i++) { + C[i] = D[i]; + D[i] = C[i]; + } + +We then check that the loop is broken into four for loops like so, if the pass +is run twice: + for (int i = 0; i < 10; i++) + A[i] = B[i]; + for (int i = 0; i < 10; i++) + B[i] = A[i]; + for (int i = 0; i < 10; i++) + C[i] = D[i]; + for (int i = 0; i < 10; i++) + D[i] = C[i]; + +*/ + +TEST_F(FissionClassTest, FissionMultipleLoadStores) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "A" + OpName %5 "B" + OpName %6 "C" + OpName %7 "D" + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpTypeFloat 32 + %16 = OpTypeInt 32 0 + %17 = OpConstant %16 10 + %18 = OpTypeArray %15 %17 + %19 = OpTypePointer Function %18 + %20 = OpTypePointer Function %15 + %21 = OpConstant %10 1 + %2 = OpFunction %8 None %9 + %22 = OpLabel + %3 = OpVariable %11 Function + %4 = OpVariable %19 Function + %5 = OpVariable %19 Function + %6 = OpVariable %19 Function + %7 = OpVariable %19 Function + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %10 %12 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %14 %24 %13 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + %31 = OpAccessChain %20 %5 %24 + %32 = OpLoad %15 %31 + %33 = OpAccessChain %20 %4 %24 + OpStore %33 %32 + %34 = OpAccessChain %20 %4 %24 + %35 = OpLoad %15 %34 + %36 = OpAccessChain %20 %5 %24 + OpStore %36 %35 + %37 = OpAccessChain %20 %7 %24 + %38 = OpLoad %15 %37 + %39 = OpAccessChain %20 %6 %24 + OpStore %39 %38 + %40 = OpAccessChain %20 %6 %24 + %41 = OpLoad %15 %40 + %42 = OpAccessChain %20 %7 %24 + OpStore %42 %41 + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %10 %24 %21 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +OpName %6 "C" +OpName %7 "D" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypeInt 32 0 +%17 = OpConstant %16 10 +%18 = OpTypeArray %15 %17 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %15 +%21 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%22 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %19 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +%7 = OpVariable %19 Function +OpBranch %43 +%43 = OpLabel +%44 = OpPhi %10 %12 %22 %61 %60 +OpLoopMerge %62 %60 None +OpBranch %45 +%45 = OpLabel +%46 = OpSLessThan %14 %44 %13 +OpBranchConditional %46 %47 %62 +%47 = OpLabel +%48 = OpAccessChain %20 %5 %44 +%49 = OpLoad %15 %48 +%50 = OpAccessChain %20 %4 %44 +OpStore %50 %49 +%51 = OpAccessChain %20 %4 %44 +%52 = OpLoad %15 %51 +%53 = OpAccessChain %20 %5 %44 +OpStore %53 %52 +OpBranch %60 +%60 = OpLabel +%61 = OpIAdd %10 %44 %21 +OpBranch %43 +%62 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %10 %12 %62 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %14 %24 %13 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%37 = OpAccessChain %20 %7 %24 +%38 = OpLoad %15 %37 +%39 = OpAccessChain %20 %6 %24 +OpStore %39 %38 +%40 = OpAccessChain %20 %6 %24 +%41 = OpLoad %15 %40 +%42 = OpAccessChain %20 %7 %24 +OpStore %42 %41 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %10 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + +const std::string expected_multiple_passes = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +OpName %6 "C" +OpName %7 "D" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypeInt 32 0 +%17 = OpConstant %16 10 +%18 = OpTypeArray %15 %17 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %15 +%21 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%22 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %19 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +%7 = OpVariable %19 Function +OpBranch %63 +%63 = OpLabel +%64 = OpPhi %10 %12 %22 %75 %74 +OpLoopMerge %76 %74 None +OpBranch %65 +%65 = OpLabel +%66 = OpSLessThan %14 %64 %13 +OpBranchConditional %66 %67 %76 +%67 = OpLabel +%68 = OpAccessChain %20 %5 %64 +%69 = OpLoad %15 %68 +%70 = OpAccessChain %20 %4 %64 +OpStore %70 %69 +OpBranch %74 +%74 = OpLabel +%75 = OpIAdd %10 %64 %21 +OpBranch %63 +%76 = OpLabel +OpBranch %43 +%43 = OpLabel +%44 = OpPhi %10 %12 %76 %61 %60 +OpLoopMerge %62 %60 None +OpBranch %45 +%45 = OpLabel +%46 = OpSLessThan %14 %44 %13 +OpBranchConditional %46 %47 %62 +%47 = OpLabel +%51 = OpAccessChain %20 %4 %44 +%52 = OpLoad %15 %51 +%53 = OpAccessChain %20 %5 %44 +OpStore %53 %52 +OpBranch %60 +%60 = OpLabel +%61 = OpIAdd %10 %44 %21 +OpBranch %43 +%62 = OpLabel +OpBranch %77 +%77 = OpLabel +%78 = OpPhi %10 %12 %62 %89 %88 +OpLoopMerge %90 %88 None +OpBranch %79 +%79 = OpLabel +%80 = OpSLessThan %14 %78 %13 +OpBranchConditional %80 %81 %90 +%81 = OpLabel +%82 = OpAccessChain %20 %7 %78 +%83 = OpLoad %15 %82 +%84 = OpAccessChain %20 %6 %78 +OpStore %84 %83 +OpBranch %88 +%88 = OpLabel +%89 = OpIAdd %10 %78 %21 +OpBranch %77 +%90 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %10 %12 %90 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %14 %24 %13 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%40 = OpAccessChain %20 %6 %24 +%41 = OpLoad %15 %40 +%42 = OpAccessChain %20 %7 %24 +OpStore %42 %41 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %10 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on +std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); +Module* module = context->module(); +EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + +SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); +SinglePassRunAndCheck(source, expected, true); + +// By passing 1 as argument we are using the constructor which makes the +// critera to split the loop be if the registers in the loop exceede 1. By +// using this constructor we are also enabling multiple passes (disabled by +// default). +SinglePassRunAndCheck(source, expected_multiple_passes, true, + 1); +} + +/* +#version 430 +void main(void) { + int accumulator = 0; + float X[10]; + float Y[10]; + + for (int i = 0; i < 10; i++) { + X[i] = Y[i]; + Y[i] = X[i]; + accumulator += i; + } +} + +This should be split into the equivalent of: + +#version 430 +void main(void) { + int accumulator = 0; + float X[10]; + float Y[10]; + + for (int i = 0; i < 10; i++) { + X[i] = Y[i]; + } + for (int i = 0; i < 10; i++) { + Y[i] = X[i]; + accumulator += i; + } +} +*/ +TEST_F(FissionClassTest, FissionWithAccumulator) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "accumulator" + OpName %4 "i" + OpName %5 "X" + OpName %6 "Y" + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypePointer Function %17 + %19 = OpTypePointer Function %14 + %20 = OpConstant %9 1 + %2 = OpFunction %7 None %8 + %21 = OpLabel + %3 = OpVariable %10 Function + %4 = OpVariable %10 Function + %5 = OpVariable %18 Function + %6 = OpVariable %18 Function + OpBranch %22 + %22 = OpLabel + %23 = OpPhi %9 %11 %21 %24 %25 + %26 = OpPhi %9 %11 %21 %27 %25 + OpLoopMerge %28 %25 None + OpBranch %29 + %29 = OpLabel + %30 = OpSLessThan %13 %26 %12 + OpBranchConditional %30 %31 %28 + %31 = OpLabel + %32 = OpAccessChain %19 %6 %26 + %33 = OpLoad %14 %32 + %34 = OpAccessChain %19 %5 %26 + OpStore %34 %33 + %35 = OpAccessChain %19 %5 %26 + %36 = OpLoad %14 %35 + %37 = OpAccessChain %19 %6 %26 + OpStore %37 %36 + %24 = OpIAdd %9 %23 %26 + OpBranch %25 + %25 = OpLabel + %27 = OpIAdd %9 %26 %20 + OpBranch %22 + %28 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "accumulator" +OpName %4 "i" +OpName %5 "X" +OpName %6 "Y" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%21 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %18 Function +%6 = OpVariable %18 Function +OpBranch %38 +%38 = OpLabel +%40 = OpPhi %9 %11 %21 %52 %51 +OpLoopMerge %53 %51 None +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %13 %40 %12 +OpBranchConditional %42 %43 %53 +%43 = OpLabel +%44 = OpAccessChain %19 %6 %40 +%45 = OpLoad %14 %44 +%46 = OpAccessChain %19 %5 %40 +OpStore %46 %45 +OpBranch %51 +%51 = OpLabel +%52 = OpIAdd %9 %40 %20 +OpBranch %38 +%53 = OpLabel +OpBranch %22 +%22 = OpLabel +%23 = OpPhi %9 %11 %53 %24 %25 +%26 = OpPhi %9 %11 %53 %27 %25 +OpLoopMerge %28 %25 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %13 %26 %12 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +%35 = OpAccessChain %19 %5 %26 +%36 = OpLoad %14 %35 +%37 = OpAccessChain %19 %6 %26 +OpStore %37 %36 +%24 = OpIAdd %9 %23 %26 +OpBranch %25 +%25 = OpLabel +%27 = OpIAdd %9 %26 %20 +OpBranch %22 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +Generated from the following glsl: + +#version 430 +layout(location=0) out float x; +layout(location=1) out float y; + +void main(void) { + float accumulator_1 = 0; + float accumulator_2 = 0; + for (int i = 0; i < 10; i++) { + accumulator_1 += i; + accumulator_2 += i; + } + + x = accumulator_1; + y = accumulator_2; +} + +Should be split into equivalent of: + +void main(void) { + float accumulator_1 = 0; + float accumulator_2 = 0; + for (int i = 0; i < 10; i++) { + accumulator_1 += i; + } + + for (int i = 0; i < 10; i++) { + accumulator_2 += i; + } + x = accumulator_1; + y = accumulator_2; +} + +*/ +TEST_F(FissionClassTest, FissionWithPhisUsedOutwithLoop) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %5 "accumulator_1" + OpName %6 "accumulator_2" + OpName %7 "i" + OpName %3 "x" + OpName %4 "y" + OpDecorate %3 Location 0 + OpDecorate %4 Location 1 + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeFloat 32 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpTypeInt 32 1 + %14 = OpTypePointer Function %13 + %15 = OpConstant %13 0 + %16 = OpConstant %13 10 + %17 = OpTypeBool + %18 = OpConstant %13 1 + %19 = OpTypePointer Output %10 + %3 = OpVariable %19 Output + %4 = OpVariable %19 Output + %2 = OpFunction %8 None %9 + %20 = OpLabel + %5 = OpVariable %11 Function + %6 = OpVariable %11 Function + %7 = OpVariable %14 Function + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %10 %12 %20 %23 %24 + %25 = OpPhi %10 %12 %20 %26 %24 + %27 = OpPhi %13 %15 %20 %28 %24 + OpLoopMerge %29 %24 None + OpBranch %30 + %30 = OpLabel + %31 = OpSLessThan %17 %27 %16 + OpBranchConditional %31 %32 %29 + %32 = OpLabel + %33 = OpConvertSToF %10 %27 + %26 = OpFAdd %10 %25 %33 + %34 = OpConvertSToF %10 %27 + %23 = OpFAdd %10 %22 %34 + OpBranch %24 + %24 = OpLabel + %28 = OpIAdd %13 %27 %18 + OpStore %7 %28 + OpBranch %21 + %29 = OpLabel + OpStore %3 %25 + OpStore %4 %22 + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 %4 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %5 "accumulator_1" +OpName %6 "accumulator_2" +OpName %7 "i" +OpName %3 "x" +OpName %4 "y" +OpDecorate %3 Location 0 +OpDecorate %4 Location 1 +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeFloat 32 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpTypeInt 32 1 +%14 = OpTypePointer Function %13 +%15 = OpConstant %13 0 +%16 = OpConstant %13 10 +%17 = OpTypeBool +%18 = OpConstant %13 1 +%19 = OpTypePointer Output %10 +%3 = OpVariable %19 Output +%4 = OpVariable %19 Output +%2 = OpFunction %8 None %9 +%20 = OpLabel +%5 = OpVariable %11 Function +%6 = OpVariable %11 Function +%7 = OpVariable %14 Function +OpBranch %35 +%35 = OpLabel +%37 = OpPhi %10 %12 %20 %43 %46 +%38 = OpPhi %13 %15 %20 %47 %46 +OpLoopMerge %48 %46 None +OpBranch %39 +%39 = OpLabel +%40 = OpSLessThan %17 %38 %16 +OpBranchConditional %40 %41 %48 +%41 = OpLabel +%42 = OpConvertSToF %10 %38 +%43 = OpFAdd %10 %37 %42 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %13 %38 %18 +OpStore %7 %47 +OpBranch %35 +%48 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %10 %12 %48 %23 %24 +%27 = OpPhi %13 %15 %48 %28 %24 +OpLoopMerge %29 %24 None +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %17 %27 %16 +OpBranchConditional %31 %32 %29 +%32 = OpLabel +%34 = OpConvertSToF %10 %27 +%23 = OpFAdd %10 %22 %34 +OpBranch %24 +%24 = OpLabel +%28 = OpIAdd %13 %27 %18 +OpStore %7 %28 +OpBranch %21 +%29 = OpLabel +OpStore %3 %37 +OpStore %4 %22 +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + A[i][j] = B[i][j]; + B[i][j] = A[i][j]; + } + } +} + +Should be split into equivalent of: + +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + A[i][j] = B[i][j]; + } + for (int j = 0; j < 10; j++) { + B[i][j] = A[i][j]; + } + } +} + + +*/ +TEST_F(FissionClassTest, FissionNested) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "j" + OpName %5 "A" + OpName %6 "B" + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypeArray %17 %16 + %19 = OpTypePointer Function %18 + %20 = OpTypePointer Function %14 + %21 = OpConstant %9 1 + %2 = OpFunction %7 None %8 + %22 = OpLabel + %3 = OpVariable %10 Function + %4 = OpVariable %10 Function + %5 = OpVariable %19 Function + %6 = OpVariable %19 Function + OpStore %3 %11 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %9 %11 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %13 %24 %12 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + OpStore %4 %11 + OpBranch %31 + %31 = OpLabel + %32 = OpPhi %9 %11 %30 %33 %34 + OpLoopMerge %35 %34 None + OpBranch %36 + %36 = OpLabel + %37 = OpSLessThan %13 %32 %12 + OpBranchConditional %37 %38 %35 + %38 = OpLabel + %39 = OpAccessChain %20 %6 %24 %32 + %40 = OpLoad %14 %39 + %41 = OpAccessChain %20 %5 %24 %32 + OpStore %41 %40 + %42 = OpAccessChain %20 %5 %24 %32 + %43 = OpLoad %14 %42 + %44 = OpAccessChain %20 %6 %24 %32 + OpStore %44 %43 + OpBranch %34 + %34 = OpLabel + %33 = OpIAdd %9 %32 %21 + OpStore %4 %33 + OpBranch %31 + %35 = OpLabel + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %9 %24 %21 + OpStore %3 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "j" +OpName %5 "A" +OpName %6 "B" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypeArray %17 %16 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %14 +%21 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +OpStore %3 %11 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %11 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %13 %24 %12 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +OpStore %4 %11 +OpBranch %45 +%45 = OpLabel +%46 = OpPhi %9 %11 %30 %57 %56 +OpLoopMerge %58 %56 None +OpBranch %47 +%47 = OpLabel +%48 = OpSLessThan %13 %46 %12 +OpBranchConditional %48 %49 %58 +%49 = OpLabel +%50 = OpAccessChain %20 %6 %24 %46 +%51 = OpLoad %14 %50 +%52 = OpAccessChain %20 %5 %24 %46 +OpStore %52 %51 +OpBranch %56 +%56 = OpLabel +%57 = OpIAdd %9 %46 %21 +OpStore %4 %57 +OpBranch %45 +%58 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpPhi %9 %11 %58 %33 %34 +OpLoopMerge %35 %34 None +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %13 %32 %12 +OpBranchConditional %37 %38 %35 +%38 = OpLabel +%42 = OpAccessChain %20 %5 %24 %32 +%43 = OpLoad %14 %42 +%44 = OpAccessChain %20 %6 %24 %32 +OpStore %44 %43 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %9 %32 %21 +OpStore %4 %33 +OpBranch %31 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %21 +OpStore %3 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + int accumulator = 0; + float A[10]; + float B[10]; + float C[10]; + + for (int i = 0; i < 10; i++) { + int c = C[i]; + A[i] = B[i]; + B[i] = A[i] + c; + } +} + +This loop should not be split as we would have to break the order of the loads +to do so. It would be grouped into two sets: + +1 + int c = C[i]; + B[i] = A[i] + c; + +2 + A[i] = B[i]; + +To keep the load C[i] in the same order we would need to put B[i] ahead of that +*/ +TEST_F(FissionClassTest, FissionLoad) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "c" +OpName %5 "C" +OpName %6 "A" +OpName %7 "B" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypePointer Function %15 +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 10 +%19 = OpTypeArray %15 %18 +%20 = OpTypePointer Function %19 +%21 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%22 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %16 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +%7 = OpVariable %20 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %10 %12 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %14 %24 %13 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpAccessChain %16 %5 %24 +%32 = OpLoad %15 %31 +OpStore %4 %32 +%33 = OpAccessChain %16 %7 %24 +%34 = OpLoad %15 %33 +%35 = OpAccessChain %16 %6 %24 +OpStore %35 %34 +%36 = OpAccessChain %16 %6 %24 +%37 = OpLoad %15 %36 +%38 = OpFAdd %15 %37 %32 +%39 = OpAccessChain %16 %7 %24 +OpStore %39 %38 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %10 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +layout(location=0) flat in int condition; +void main(void) { + float A[10]; + float B[10]; + + for (int i = 0; i < 10; i++) { + if (condition == 1) + A[i] = B[i]; + else + B[i] = A[i]; + } +} + + +When this is split we leave the condition check and control flow inplace and +leave its removal for dead code elimination. + +#version 430 +layout(location=0) flat in int condition; +void main(void) { + float A[10]; + float B[10]; + + for (int i = 0; i < 10; i++) { + if (condition == 1) + A[i] = B[i]; + else + ; + } + for (int i = 0; i < 10; i++) { + if (condition == 1) + ; + else + B[i] = A[i]; + } +} + + +*/ +TEST_F(FissionClassTest, FissionControlFlow) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %4 "i" + OpName %3 "condition" + OpName %5 "A" + OpName %6 "B" + OpDecorate %3 Flat + OpDecorate %3 Location 0 + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypePointer Input %9 + %3 = OpVariable %14 Input + %15 = OpConstant %9 1 + %16 = OpTypeFloat 32 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 10 + %19 = OpTypeArray %16 %18 + %20 = OpTypePointer Function %19 + %21 = OpTypePointer Function %16 + %2 = OpFunction %7 None %8 + %22 = OpLabel + %4 = OpVariable %10 Function + %5 = OpVariable %20 Function + %6 = OpVariable %20 Function + %31 = OpLoad %9 %3 + OpStore %4 %11 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %9 %11 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %13 %24 %12 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + %32 = OpIEqual %13 %31 %15 + OpSelectionMerge %33 None + OpBranchConditional %32 %34 %35 + %34 = OpLabel + %36 = OpAccessChain %21 %6 %24 + %37 = OpLoad %16 %36 + %38 = OpAccessChain %21 %5 %24 + OpStore %38 %37 + OpBranch %33 + %35 = OpLabel + %39 = OpAccessChain %21 %5 %24 + %40 = OpLoad %16 %39 + %41 = OpAccessChain %21 %6 %24 + OpStore %41 %40 + OpBranch %33 + %33 = OpLabel + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %9 %24 %15 + OpStore %4 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %4 "i" +OpName %3 "condition" +OpName %5 "A" +OpName %6 "B" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypePointer Input %9 +%3 = OpVariable %14 Input +%15 = OpConstant %9 1 +%16 = OpTypeFloat 32 +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 10 +%19 = OpTypeArray %16 %18 +%20 = OpTypePointer Function %19 +%21 = OpTypePointer Function %16 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%4 = OpVariable %10 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +%23 = OpLoad %9 %3 +OpStore %4 %11 +OpBranch %42 +%42 = OpLabel +%43 = OpPhi %9 %11 %22 %58 %57 +OpLoopMerge %59 %57 None +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThan %13 %43 %12 +OpBranchConditional %45 %46 %59 +%46 = OpLabel +%47 = OpIEqual %13 %23 %15 +OpSelectionMerge %56 None +OpBranchConditional %47 %52 %48 +%48 = OpLabel +OpBranch %56 +%52 = OpLabel +%53 = OpAccessChain %21 %6 %43 +%54 = OpLoad %16 %53 +%55 = OpAccessChain %21 %5 %43 +OpStore %55 %54 +OpBranch %56 +%56 = OpLabel +OpBranch %57 +%57 = OpLabel +%58 = OpIAdd %9 %43 %15 +OpStore %4 %58 +OpBranch %42 +%59 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %9 %11 %59 %26 %27 +OpLoopMerge %28 %27 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %13 %25 %12 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +%32 = OpIEqual %13 %23 %15 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %35 +%34 = OpLabel +OpBranch %33 +%35 = OpLabel +%39 = OpAccessChain %21 %5 %25 +%40 = OpLoad %16 %39 +%41 = OpAccessChain %21 %6 %25 +OpStore %41 %40 +OpBranch %33 +%33 = OpLabel +OpBranch %27 +%27 = OpLabel +%26 = OpIAdd %9 %25 %15 +OpStore %4 %26 +OpBranch %24 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + if (i == 1) + B[i] = A[i]; + else if (i == 2) + A[i] = B[i]; + else + A[i] = 0; + } +} + +After running the pass with multiple splits enabled (via register threshold of +1) we expect the equivalent of: + +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + if (i == 1) + B[i] = A[i]; + else if (i == 2) + else + } + for (int i = 0; i < 10; i++) { + if (i == 1) + else if (i == 2) + A[i] = B[i]; + else + } + for (int i = 0; i < 10; i++) { + if (i == 1) + else if (i == 2) + else + A[i] = 0; + } + +} + +*/ +TEST_F(FissionClassTest, FissionControlFlow2) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "B" + OpName %5 "A" + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpConstant %8 1 + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypePointer Function %17 + %19 = OpTypePointer Function %14 + %20 = OpConstant %8 2 + %21 = OpConstant %14 0 + %2 = OpFunction %6 None %7 + %22 = OpLabel + %3 = OpVariable %9 Function + %4 = OpVariable %18 Function + %5 = OpVariable %18 Function + OpStore %3 %10 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %8 %10 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %12 %24 %11 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + %31 = OpIEqual %12 %24 %13 + OpSelectionMerge %32 None + OpBranchConditional %31 %33 %34 + %33 = OpLabel + %35 = OpAccessChain %19 %5 %24 + %36 = OpLoad %14 %35 + %37 = OpAccessChain %19 %4 %24 + OpStore %37 %36 + OpBranch %32 + %34 = OpLabel + %38 = OpIEqual %12 %24 %20 + OpSelectionMerge %39 None + OpBranchConditional %38 %40 %41 + %40 = OpLabel + %42 = OpAccessChain %19 %4 %24 + %43 = OpLoad %14 %42 + %44 = OpAccessChain %19 %5 %24 + OpStore %44 %43 + OpBranch %39 + %41 = OpLabel + %45 = OpAccessChain %19 %5 %24 + OpStore %45 %21 + OpBranch %39 + %39 = OpLabel + OpBranch %32 + %32 = OpLabel + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %8 %24 %13 + OpStore %3 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpConstant %8 1 +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %8 2 +%21 = OpConstant %14 0 +%2 = OpFunction %6 None %7 +%22 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %18 Function +%5 = OpVariable %18 Function +OpStore %3 %10 +OpBranch %46 +%46 = OpLabel +%47 = OpPhi %8 %10 %22 %67 %66 +OpLoopMerge %68 %66 None +OpBranch %48 +%48 = OpLabel +%49 = OpSLessThan %12 %47 %11 +OpBranchConditional %49 %50 %68 +%50 = OpLabel +%51 = OpIEqual %12 %47 %13 +OpSelectionMerge %65 None +OpBranchConditional %51 %61 %52 +%52 = OpLabel +%53 = OpIEqual %12 %47 %20 +OpSelectionMerge %60 None +OpBranchConditional %53 %56 %54 +%54 = OpLabel +OpBranch %60 +%56 = OpLabel +OpBranch %60 +%60 = OpLabel +OpBranch %65 +%61 = OpLabel +%62 = OpAccessChain %19 %5 %47 +%63 = OpLoad %14 %62 +%64 = OpAccessChain %19 %4 %47 +OpStore %64 %63 +OpBranch %65 +%65 = OpLabel +OpBranch %66 +%66 = OpLabel +%67 = OpIAdd %8 %47 %13 +OpStore %3 %67 +OpBranch %46 +%68 = OpLabel +OpBranch %69 +%69 = OpLabel +%70 = OpPhi %8 %10 %68 %87 %86 +OpLoopMerge %88 %86 None +OpBranch %71 +%71 = OpLabel +%72 = OpSLessThan %12 %70 %11 +OpBranchConditional %72 %73 %88 +%73 = OpLabel +%74 = OpIEqual %12 %70 %13 +OpSelectionMerge %85 None +OpBranchConditional %74 %84 %75 +%75 = OpLabel +%76 = OpIEqual %12 %70 %20 +OpSelectionMerge %83 None +OpBranchConditional %76 %79 %77 +%77 = OpLabel +OpBranch %83 +%79 = OpLabel +%80 = OpAccessChain %19 %4 %70 +%81 = OpLoad %14 %80 +%82 = OpAccessChain %19 %5 %70 +OpStore %82 %81 +OpBranch %83 +%83 = OpLabel +OpBranch %85 +%84 = OpLabel +OpBranch %85 +%85 = OpLabel +OpBranch %86 +%86 = OpLabel +%87 = OpIAdd %8 %70 %13 +OpStore %3 %87 +OpBranch %69 +%88 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %8 %10 %88 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %12 %24 %11 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpIEqual %12 %24 %13 +OpSelectionMerge %32 None +OpBranchConditional %31 %33 %34 +%33 = OpLabel +OpBranch %32 +%34 = OpLabel +%38 = OpIEqual %12 %24 %20 +OpSelectionMerge %39 None +OpBranchConditional %38 %40 %41 +%40 = OpLabel +OpBranch %39 +%41 = OpLabel +%45 = OpAccessChain %19 %5 %24 +OpStore %45 %21 +OpBranch %39 +%39 = OpLabel +OpBranch %32 +%32 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %8 %24 %13 +OpStore %3 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true, 1); +} + +/* +#version 430 +layout(location=0) flat in int condition; +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + memoryBarrier(); + A[i] = B[i]; + } +} + +This should not be split due to the memory barrier. +*/ +TEST_F(FissionClassTest, FissionBarrier) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +OpName %3 "condition" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %15 1 +%21 = OpConstant %15 4048 +%22 = OpConstant %9 1 +%23 = OpTypePointer Input %9 +%3 = OpVariable %23 Input +%2 = OpFunction %7 None %8 +%24 = OpLabel +%4 = OpVariable %10 Function +%5 = OpVariable %18 Function +%6 = OpVariable %18 Function +OpStore %4 %11 +OpBranch %25 +%25 = OpLabel +%26 = OpPhi %9 %11 %24 %27 %28 +OpLoopMerge %29 %28 None +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %13 %26 %12 +OpBranchConditional %31 %32 %29 +%32 = OpLabel +%33 = OpAccessChain %19 %6 %26 +%34 = OpLoad %14 %33 +%35 = OpAccessChain %19 %5 %26 +OpStore %35 %34 +OpMemoryBarrier %20 %21 +%36 = OpAccessChain %19 %5 %26 +%37 = OpLoad %14 %36 +%38 = OpAccessChain %19 %6 %26 +OpStore %38 %37 +OpBranch %28 +%28 = OpLabel +%27 = OpIAdd %9 %26 %22 +OpStore %4 %27 +OpBranch %25 +%29 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + if ( i== 1) + break; + A[i] = B[i]; + } +} + +This should not be split due to the break. +*/ +TEST_F(FissionClassTest, FissionBreak) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIEqual %12 %22 %19 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpBranch %25 +%33 = OpLabel +%35 = OpAccessChain %18 %4 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %5 %22 +OpStore %37 %36 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + if ( i== 1) + continue; + A[i] = B[i]; + } +} + +This loop should be split into: + + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + if ( i== 1) + continue; + } + for (int i = 0; i < 10; i++) { + if ( i== 1) + continue; + A[i] = B[i]; + } +The continue block in the first loop is left to DCE. +} + + +*/ +TEST_F(FissionClassTest, FissionContinue) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIEqual %12 %22 %19 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpBranch %24 +%33 = OpLabel +%35 = OpAccessChain %18 %4 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %5 %22 +OpStore %37 %36 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %38 +%38 = OpLabel +%39 = OpPhi %8 %10 %20 %53 %52 +OpLoopMerge %54 %52 None +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %12 %39 %11 +OpBranchConditional %41 %42 %54 +%42 = OpLabel +%43 = OpAccessChain %18 %5 %39 +%44 = OpLoad %13 %43 +%45 = OpAccessChain %18 %4 %39 +OpStore %45 %44 +%46 = OpIEqual %12 %39 %19 +OpSelectionMerge %47 None +OpBranchConditional %46 %51 %47 +%47 = OpLabel +OpBranch %52 +%51 = OpLabel +OpBranch %52 +%52 = OpLabel +%53 = OpIAdd %8 %39 %19 +OpStore %3 %53 +OpBranch %38 +%54 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %54 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%32 = OpIEqual %12 %22 %19 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpBranch %24 +%33 = OpLabel +%35 = OpAccessChain %18 %4 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %5 %22 +OpStore %37 %36 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + int i = 0; + do { + B[i] = A[i]; + A[i] = B[i]; + ++i; + } while (i < 10); +} + + +Check that this is split into: + int i = 0; + do { + B[i] = A[i]; + ++i; + } while (i < 10); + + i = 0; + do { + A[i] = B[i]; + ++i; + } while (i < 10); + + +*/ +TEST_F(FissionClassTest, FissionDoWhile) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 10 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %8 1 +%18 = OpConstant %8 10 +%19 = OpTypeBool +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %15 Function +%5 = OpVariable %15 Function +OpStore %3 %10 +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpAccessChain %16 %5 %22 +%28 = OpLoad %11 %27 +%29 = OpAccessChain %16 %4 %22 +OpStore %29 %28 +%30 = OpAccessChain %16 %4 %22 +%31 = OpLoad %11 %30 +%32 = OpAccessChain %16 %5 %22 +OpStore %32 %31 +%23 = OpIAdd %8 %22 %17 +OpStore %3 %23 +OpBranch %24 +%24 = OpLabel +%33 = OpSLessThan %19 %23 %18 +OpBranchConditional %33 %21 %25 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 10 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %8 1 +%18 = OpConstant %8 10 +%19 = OpTypeBool +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %15 Function +%5 = OpVariable %15 Function +OpStore %3 %10 +OpBranch %34 +%34 = OpLabel +%35 = OpPhi %8 %10 %20 %43 %44 +OpLoopMerge %46 %44 None +OpBranch %36 +%36 = OpLabel +%37 = OpAccessChain %16 %5 %35 +%38 = OpLoad %11 %37 +%39 = OpAccessChain %16 %4 %35 +OpStore %39 %38 +%43 = OpIAdd %8 %35 %17 +OpStore %3 %43 +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThan %19 %43 %18 +OpBranchConditional %45 %34 %46 +%46 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %46 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%30 = OpAccessChain %16 %4 %22 +%31 = OpLoad %11 %30 +%32 = OpAccessChain %16 %5 %22 +OpStore %32 %31 +%23 = OpIAdd %8 %22 %17 +OpStore %3 %23 +OpBranch %24 +%24 = OpLabel +%33 = OpSLessThan %19 %23 %18 +OpBranchConditional %33 %21 %25 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* + +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) { + B[i][j] = A[i][i]; + A[i][i] = B[i][j + 1]; + } + } +} + + +This loop can't be split because the load B[i][j + 1] is dependent on the store +B[i][j]. + +*/ +TEST_F(FissionClassTest, FissionNestedDependency) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "j" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypeArray %17 %16 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %14 +%21 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %11 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %13 %24 %12 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpPhi %9 %11 %30 %33 %34 +OpLoopMerge %35 %34 None +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %13 %32 %12 +OpBranchConditional %37 %38 %35 +%38 = OpLabel +%39 = OpAccessChain %20 %6 %32 %32 +%40 = OpLoad %14 %39 +%41 = OpAccessChain %20 %5 %32 %24 +OpStore %41 %40 +%42 = OpIAdd %9 %24 %21 +%43 = OpAccessChain %20 %5 %32 %42 +%44 = OpLoad %14 %43 +%45 = OpAccessChain %20 %6 %32 %32 +OpStore %45 %44 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %9 %32 %21 +OpBranch %31 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) { + B[i][i] = A[i][j]; + A[i][j+1] = B[i][i]; + } + } +} + +This loop should not be split as the load A[i][j+1] would be reading a value +written in the store A[i][j] which would be hit before A[i][j+1] if the loops +where split but would not get hit before the read currently. + +*/ +TEST_F(FissionClassTest, FissionNestedDependency2) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "j" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypeArray %17 %16 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %14 +%21 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +OpStore %3 %11 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %11 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %13 %24 %12 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +OpStore %4 %11 +OpBranch %31 +%31 = OpLabel +%32 = OpPhi %9 %11 %30 %33 %34 +OpLoopMerge %35 %34 None +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %13 %32 %12 +OpBranchConditional %37 %38 %35 +%38 = OpLabel +%39 = OpAccessChain %20 %6 %32 %24 +%40 = OpLoad %14 %39 +%41 = OpAccessChain %20 %5 %32 %32 +OpStore %41 %40 +%42 = OpIAdd %9 %24 %21 +%43 = OpAccessChain %20 %5 %32 %32 +%44 = OpLoad %14 %43 +%45 = OpAccessChain %20 %6 %32 %42 +OpStore %45 %44 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %9 %32 %21 +OpStore %4 %33 +OpBranch %31 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %21 +OpStore %3 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) { + B[i][j] = A[i][j]; + A[i][j] = B[i][j]; + } + for (int i = 0; i < 10; ++i) { + B[i][j] = A[i][j]; + A[i][j] = B[i][j]; + } + } +} + + + +Should be split into: + +for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) + B[i][j] = A[i][j]; + for (int i = 0; i < 10; ++i) + A[i][j] = B[i][j]; + for (int i = 0; i < 10; ++i) + B[i][j] = A[i][j]; + for (int i = 0; i < 10; ++i) + A[i][j] = B[i][j]; +*/ +TEST_F(FissionClassTest, FissionMultipleLoopsNested) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "j" + OpName %4 "i" + OpName %5 "B" + OpName %6 "A" + OpName %7 "i" + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpTypeFloat 32 + %16 = OpTypeInt 32 0 + %17 = OpConstant %16 10 + %18 = OpTypeArray %15 %17 + %19 = OpTypeArray %18 %17 + %20 = OpTypePointer Function %19 + %21 = OpTypePointer Function %15 + %22 = OpConstant %10 1 + %2 = OpFunction %8 None %9 + %23 = OpLabel + %3 = OpVariable %11 Function + %4 = OpVariable %11 Function + %5 = OpVariable %20 Function + %6 = OpVariable %20 Function + %7 = OpVariable %11 Function + OpStore %3 %12 + OpBranch %24 + %24 = OpLabel + %25 = OpPhi %10 %12 %23 %26 %27 + OpLoopMerge %28 %27 None + OpBranch %29 + %29 = OpLabel + %30 = OpSLessThan %14 %25 %13 + OpBranchConditional %30 %31 %28 + %31 = OpLabel + OpStore %4 %12 + OpBranch %32 + %32 = OpLabel + %33 = OpPhi %10 %12 %31 %34 %35 + OpLoopMerge %36 %35 None + OpBranch %37 + %37 = OpLabel + %38 = OpSLessThan %14 %33 %13 + OpBranchConditional %38 %39 %36 + %39 = OpLabel + %40 = OpAccessChain %21 %6 %33 %25 + %41 = OpLoad %15 %40 + %42 = OpAccessChain %21 %5 %33 %25 + OpStore %42 %41 + %43 = OpAccessChain %21 %5 %33 %25 + %44 = OpLoad %15 %43 + %45 = OpAccessChain %21 %6 %33 %25 + OpStore %45 %44 + OpBranch %35 + %35 = OpLabel + %34 = OpIAdd %10 %33 %22 + OpStore %4 %34 + OpBranch %32 + %36 = OpLabel + OpStore %7 %12 + OpBranch %46 + %46 = OpLabel + %47 = OpPhi %10 %12 %36 %48 %49 + OpLoopMerge %50 %49 None + OpBranch %51 + %51 = OpLabel + %52 = OpSLessThan %14 %47 %13 + OpBranchConditional %52 %53 %50 + %53 = OpLabel + %54 = OpAccessChain %21 %6 %47 %25 + %55 = OpLoad %15 %54 + %56 = OpAccessChain %21 %5 %47 %25 + OpStore %56 %55 + %57 = OpAccessChain %21 %5 %47 %25 + %58 = OpLoad %15 %57 + %59 = OpAccessChain %21 %6 %47 %25 + OpStore %59 %58 + OpBranch %49 + %49 = OpLabel + %48 = OpIAdd %10 %47 %22 + OpStore %7 %48 + OpBranch %46 + %50 = OpLabel + OpBranch %27 + %27 = OpLabel + %26 = OpIAdd %10 %25 %22 + OpStore %3 %26 + OpBranch %24 + %28 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "j" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +OpName %7 "i" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypeInt 32 0 +%17 = OpConstant %16 10 +%18 = OpTypeArray %15 %17 +%19 = OpTypeArray %18 %17 +%20 = OpTypePointer Function %19 +%21 = OpTypePointer Function %15 +%22 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%23 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %11 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +%7 = OpVariable %11 Function +OpStore %3 %12 +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %10 %12 %23 %26 %27 +OpLoopMerge %28 %27 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %14 %25 %13 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +OpStore %4 %12 +OpBranch %60 +%60 = OpLabel +%61 = OpPhi %10 %12 %31 %72 %71 +OpLoopMerge %73 %71 None +OpBranch %62 +%62 = OpLabel +%63 = OpSLessThan %14 %61 %13 +OpBranchConditional %63 %64 %73 +%64 = OpLabel +%65 = OpAccessChain %21 %6 %61 %25 +%66 = OpLoad %15 %65 +%67 = OpAccessChain %21 %5 %61 %25 +OpStore %67 %66 +OpBranch %71 +%71 = OpLabel +%72 = OpIAdd %10 %61 %22 +OpStore %4 %72 +OpBranch %60 +%73 = OpLabel +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %10 %12 %73 %34 %35 +OpLoopMerge %36 %35 None +OpBranch %37 +%37 = OpLabel +%38 = OpSLessThan %14 %33 %13 +OpBranchConditional %38 %39 %36 +%39 = OpLabel +%43 = OpAccessChain %21 %5 %33 %25 +%44 = OpLoad %15 %43 +%45 = OpAccessChain %21 %6 %33 %25 +OpStore %45 %44 +OpBranch %35 +%35 = OpLabel +%34 = OpIAdd %10 %33 %22 +OpStore %4 %34 +OpBranch %32 +%36 = OpLabel +OpStore %7 %12 +OpBranch %74 +%74 = OpLabel +%75 = OpPhi %10 %12 %36 %86 %85 +OpLoopMerge %87 %85 None +OpBranch %76 +%76 = OpLabel +%77 = OpSLessThan %14 %75 %13 +OpBranchConditional %77 %78 %87 +%78 = OpLabel +%79 = OpAccessChain %21 %6 %75 %25 +%80 = OpLoad %15 %79 +%81 = OpAccessChain %21 %5 %75 %25 +OpStore %81 %80 +OpBranch %85 +%85 = OpLabel +%86 = OpIAdd %10 %75 %22 +OpStore %7 %86 +OpBranch %74 +%87 = OpLabel +OpBranch %46 +%46 = OpLabel +%47 = OpPhi %10 %12 %87 %48 %49 +OpLoopMerge %50 %49 None +OpBranch %51 +%51 = OpLabel +%52 = OpSLessThan %14 %47 %13 +OpBranchConditional %52 %53 %50 +%53 = OpLabel +%57 = OpAccessChain %21 %5 %47 %25 +%58 = OpLoad %15 %57 +%59 = OpAccessChain %21 %6 %47 %25 +OpStore %59 %58 +OpBranch %49 +%49 = OpLabel +%48 = OpIAdd %10 %47 %22 +OpStore %7 %48 +OpBranch %46 +%50 = OpLabel +OpBranch %27 +%27 = OpLabel +%26 = OpIAdd %10 %25 %22 +OpStore %3 %26 +OpBranch %24 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + const Function* function = spvtest::GetFunction(module, 2); + LoopDescriptor& pre_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(pre_pass_descriptor.NumLoops(), 3u); + EXPECT_EQ(pre_pass_descriptor.pre_begin()->NumImmediateChildren(), 2u); + + // Test that the pass transforms the ir into the expected output. + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + // Test that the loop descriptor is correctly maintained and updated by the + // pass. + LoopFissionPass loop_fission; + loop_fission.SetContextForTesting(context.get()); + loop_fission.Process(); + + function = spvtest::GetFunction(module, 2); + LoopDescriptor& post_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(post_pass_descriptor.NumLoops(), 5u); + EXPECT_EQ(post_pass_descriptor.pre_begin()->NumImmediateChildren(), 4u); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int i = 0; i < 10; ++i) { + B[i][i] = A[i][i]; + A[i][i] = B[i][i]; + } + for (int i = 0; i < 10; ++i) { + B[i][i] = A[i][i]; + A[i][i] = B[i][i] + } +} + + + +Should be split into: + + for (int i = 0; i < 10; ++i) + B[i][i] = A[i][i]; + for (int i = 0; i < 10; ++i) + A[i][i] = B[i][i]; + for (int i = 0; i < 10; ++i) + B[i][i] = A[i][i]; + for (int i = 0; i < 10; ++i) + A[i][i] = B[i][i]; +*/ +TEST_F(FissionClassTest, FissionMultipleLoops) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "B" + OpName %5 "A" + OpName %6 "i" + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypePointer Function %17 + %19 = OpTypePointer Function %14 + %20 = OpConstant %9 1 + %2 = OpFunction %7 None %8 + %21 = OpLabel + %3 = OpVariable %10 Function + %4 = OpVariable %18 Function + %5 = OpVariable %18 Function + %6 = OpVariable %10 Function + OpStore %3 %11 + OpBranch %22 + %22 = OpLabel + %23 = OpPhi %9 %11 %21 %24 %25 + OpLoopMerge %26 %25 None + OpBranch %27 + %27 = OpLabel + %28 = OpSLessThan %13 %23 %12 + OpBranchConditional %28 %29 %26 + %29 = OpLabel + %30 = OpAccessChain %19 %5 %23 + %31 = OpLoad %14 %30 + %32 = OpAccessChain %19 %4 %23 + OpStore %32 %31 + %33 = OpAccessChain %19 %4 %23 + %34 = OpLoad %14 %33 + %35 = OpAccessChain %19 %5 %23 + OpStore %35 %34 + OpBranch %25 + %25 = OpLabel + %24 = OpIAdd %9 %23 %20 + OpStore %3 %24 + OpBranch %22 + %26 = OpLabel + OpStore %6 %11 + OpBranch %36 + %36 = OpLabel + %37 = OpPhi %9 %11 %26 %38 %39 + OpLoopMerge %40 %39 None + OpBranch %41 + %41 = OpLabel + %42 = OpSLessThan %13 %37 %12 + OpBranchConditional %42 %43 %40 + %43 = OpLabel + %44 = OpAccessChain %19 %5 %37 + %45 = OpLoad %14 %44 + %46 = OpAccessChain %19 %4 %37 + OpStore %46 %45 + %47 = OpAccessChain %19 %4 %37 + %48 = OpLoad %14 %47 + %49 = OpAccessChain %19 %5 %37 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %38 = OpIAdd %9 %37 %20 + OpStore %6 %38 + OpBranch %36 + %40 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +OpName %6 "i" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%21 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %18 Function +%5 = OpVariable %18 Function +%6 = OpVariable %10 Function +OpStore %3 %11 +OpBranch %64 +%64 = OpLabel +%65 = OpPhi %9 %11 %21 %76 %75 +OpLoopMerge %77 %75 None +OpBranch %66 +%66 = OpLabel +%67 = OpSLessThan %13 %65 %12 +OpBranchConditional %67 %68 %77 +%68 = OpLabel +%69 = OpAccessChain %19 %5 %65 +%70 = OpLoad %14 %69 +%71 = OpAccessChain %19 %4 %65 +OpStore %71 %70 +OpBranch %75 +%75 = OpLabel +%76 = OpIAdd %9 %65 %20 +OpStore %3 %76 +OpBranch %64 +%77 = OpLabel +OpBranch %22 +%22 = OpLabel +%23 = OpPhi %9 %11 %77 %24 %25 +OpLoopMerge %26 %25 None +OpBranch %27 +%27 = OpLabel +%28 = OpSLessThan %13 %23 %12 +OpBranchConditional %28 %29 %26 +%29 = OpLabel +%33 = OpAccessChain %19 %4 %23 +%34 = OpLoad %14 %33 +%35 = OpAccessChain %19 %5 %23 +OpStore %35 %34 +OpBranch %25 +%25 = OpLabel +%24 = OpIAdd %9 %23 %20 +OpStore %3 %24 +OpBranch %22 +%26 = OpLabel +OpStore %6 %11 +OpBranch %50 +%50 = OpLabel +%51 = OpPhi %9 %11 %26 %62 %61 +OpLoopMerge %63 %61 None +OpBranch %52 +%52 = OpLabel +%53 = OpSLessThan %13 %51 %12 +OpBranchConditional %53 %54 %63 +%54 = OpLabel +%55 = OpAccessChain %19 %5 %51 +%56 = OpLoad %14 %55 +%57 = OpAccessChain %19 %4 %51 +OpStore %57 %56 +OpBranch %61 +%61 = OpLabel +%62 = OpIAdd %9 %51 %20 +OpStore %6 %62 +OpBranch %50 +%63 = OpLabel +OpBranch %36 +%36 = OpLabel +%37 = OpPhi %9 %11 %63 %38 %39 +OpLoopMerge %40 %39 None +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %13 %37 %12 +OpBranchConditional %42 %43 %40 +%43 = OpLabel +%47 = OpAccessChain %19 %4 %37 +%48 = OpLoad %14 %47 +%49 = OpAccessChain %19 %5 %37 +OpStore %49 %48 +OpBranch %39 +%39 = OpLabel +%38 = OpIAdd %9 %37 %20 +OpStore %6 %38 +OpBranch %36 +%40 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + const Function* function = spvtest::GetFunction(module, 2); + LoopDescriptor& pre_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(pre_pass_descriptor.NumLoops(), 2u); + EXPECT_EQ(pre_pass_descriptor.pre_begin()->NumImmediateChildren(), 0u); + + // Test that the pass transforms the ir into the expected output. + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + // Test that the loop descriptor is correctly maintained and updated by the + // pass. + LoopFissionPass loop_fission; + loop_fission.SetContextForTesting(context.get()); + loop_fission.Process(); + + function = spvtest::GetFunction(module, 2); + LoopDescriptor& post_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(post_pass_descriptor.NumLoops(), 4u); + EXPECT_EQ(post_pass_descriptor.pre_begin()->NumImmediateChildren(), 0u); +} + +/* +#version 430 +int foo() { return 1; } +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; ++i) { + B[i] = A[i]; + foo(); + A[i] = B[i]; + } +} + +This should not be split as it has a function call in it so we can't determine +if it has side effects. +*/ +TEST_F(FissionClassTest, FissionFunctionCall) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "foo(" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypeFunction %9 +%11 = OpConstant %9 1 +%12 = OpTypePointer Function %9 +%13 = OpConstant %9 0 +%14 = OpConstant %9 10 +%15 = OpTypeBool +%16 = OpTypeFloat 32 +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 10 +%19 = OpTypeArray %16 %18 +%20 = OpTypePointer Function %19 +%21 = OpTypePointer Function %16 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%4 = OpVariable %12 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +OpStore %4 %13 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %13 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %15 %24 %14 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpAccessChain %21 %6 %24 +%32 = OpLoad %16 %31 +%33 = OpAccessChain %21 %5 %24 +OpStore %33 %32 +%34 = OpFunctionCall %9 %3 +%35 = OpAccessChain %21 %5 %24 +%36 = OpLoad %16 %35 +%37 = OpAccessChain %21 %6 %24 +OpStore %37 %36 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %11 +OpStore %4 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +%3 = OpFunction %9 None %10 +%38 = OpLabel +OpReturnValue %11 +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; ++i) { + switch (i) { + case 1: + B[i] = A[i]; + break; + default: + A[i] = B[i]; + } + } +} + +This should be split into: + for (int i = 0; i < 10; ++i) { + switch (i) { + case 1: + break; + default: + A[i] = B[i]; + } + } + + for (int i = 0; i < 10; ++i) { + switch (i) { + case 1: + B[i] = A[i]; + break; + default: + break; + } + } + +*/ +TEST_F(FissionClassTest, FissionSwitchStatement) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "B" + OpName %5 "A" + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpTypePointer Function %13 + %19 = OpConstant %8 1 + %2 = OpFunction %6 None %7 + %20 = OpLabel + %3 = OpVariable %9 Function + %4 = OpVariable %17 Function + %5 = OpVariable %17 Function + OpStore %3 %10 + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %8 %10 %20 %23 %24 + OpLoopMerge %25 %24 None + OpBranch %26 + %26 = OpLabel + %27 = OpSLessThan %12 %22 %11 + OpBranchConditional %27 %28 %25 + %28 = OpLabel + OpSelectionMerge %29 None + OpSwitch %22 %30 1 %31 + %30 = OpLabel + %32 = OpAccessChain %18 %4 %22 + %33 = OpLoad %13 %32 + %34 = OpAccessChain %18 %5 %22 + OpStore %34 %33 + OpBranch %29 + %31 = OpLabel + %35 = OpAccessChain %18 %5 %22 + %36 = OpLoad %13 %35 + %37 = OpAccessChain %18 %4 %22 + OpStore %37 %36 + OpBranch %29 + %29 = OpLabel + OpBranch %24 + %24 = OpLabel + %23 = OpIAdd %8 %22 %19 + OpStore %3 %23 + OpBranch %21 + %25 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %38 +%38 = OpLabel +%39 = OpPhi %8 %10 %20 %53 %52 +OpLoopMerge %54 %52 None +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %12 %39 %11 +OpBranchConditional %41 %42 %54 +%42 = OpLabel +OpSelectionMerge %51 None +OpSwitch %39 %47 1 %43 +%43 = OpLabel +OpBranch %51 +%47 = OpLabel +%48 = OpAccessChain %18 %4 %39 +%49 = OpLoad %13 %48 +%50 = OpAccessChain %18 %5 %39 +OpStore %50 %49 +OpBranch %51 +%51 = OpLabel +OpBranch %52 +%52 = OpLabel +%53 = OpIAdd %8 %39 %19 +OpStore %3 %53 +OpBranch %38 +%54 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %54 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +OpSelectionMerge %29 None +OpSwitch %22 %30 1 %31 +%30 = OpLabel +OpBranch %29 +%31 = OpLabel +%35 = OpAccessChain %18 %5 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %4 %22 +OpStore %37 %36 +OpBranch %29 +%29 = OpLabel +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/nested_loops.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/nested_loops.cpp index 480a28040..651cdef44 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/nested_loops.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/nested_loops.cpp @@ -12,26 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include #include -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" - -#include "opt/iterator.h" -#include "opt/loop_descriptor.h" -#include "opt/pass.h" -#include "opt/tree_iterator.h" +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; bool Validate(const std::vector& bin) { @@ -150,14 +149,14 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f); + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); EXPECT_EQ(ld.NumLoops(), 3u); @@ -166,7 +165,7 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { // Not a loop header. EXPECT_EQ(ld[20], nullptr); - ir::Loop& parent_loop = *ld[21]; + Loop& parent_loop = *ld[21]; EXPECT_TRUE(parent_loop.HasNestedLoops()); EXPECT_FALSE(parent_loop.IsNested()); EXPECT_EQ(parent_loop.GetDepth(), 1u); @@ -175,7 +174,7 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23)); EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22)); - ir::Loop& child_loop_1 = *ld[28]; + Loop& child_loop_1 = *ld[28]; EXPECT_FALSE(child_loop_1.HasNestedLoops()); EXPECT_TRUE(child_loop_1.IsNested()); EXPECT_EQ(child_loop_1.GetDepth(), 2u); @@ -184,7 +183,7 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30)); EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29)); - ir::Loop& child_loop_2 = *ld[37]; + Loop& child_loop_2 = *ld[37]; EXPECT_FALSE(child_loop_2.HasNestedLoops()); EXPECT_TRUE(child_loop_2.IsNested()); EXPECT_EQ(child_loop_2.GetDepth(), 2u); @@ -194,7 +193,7 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38)); } -static void CheckLoopBlocks(ir::Loop* loop, +static void CheckLoopBlocks(Loop* loop, std::unordered_set* expected_ids) { SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id())); for (uint32_t bb_id : loop->GetBlocks()) { @@ -336,14 +335,14 @@ TEST_F(PassClassTest, TripleNestedLoop) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f); + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); EXPECT_EQ(ld.NumLoops(), 4u); @@ -360,7 +359,7 @@ TEST_F(PassClassTest, TripleNestedLoop) { std::unordered_set basic_block_in_loop = { {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32, 31, 25}}; - ir::Loop* loop = ld[23]; + Loop* loop = ld[23]; CheckLoopBlocks(loop, &basic_block_in_loop); EXPECT_TRUE(loop->HasNestedLoops()); @@ -378,7 +377,7 @@ TEST_F(PassClassTest, TripleNestedLoop) { { std::unordered_set basic_block_in_loop = { {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}}; - ir::Loop* loop = ld[30]; + Loop* loop = ld[30]; CheckLoopBlocks(loop, &basic_block_in_loop); EXPECT_TRUE(loop->HasNestedLoops()); @@ -395,7 +394,7 @@ TEST_F(PassClassTest, TripleNestedLoop) { { std::unordered_set basic_block_in_loop = {{41, 44, 47, 43}}; - ir::Loop* loop = ld[41]; + Loop* loop = ld[41]; CheckLoopBlocks(loop, &basic_block_in_loop); EXPECT_FALSE(loop->HasNestedLoops()); @@ -412,7 +411,7 @@ TEST_F(PassClassTest, TripleNestedLoop) { { std::unordered_set basic_block_in_loop = {{50, 53, 56, 52}}; - ir::Loop* loop = ld[50]; + Loop* loop = ld[50]; CheckLoopBlocks(loop, &basic_block_in_loop); EXPECT_FALSE(loop->HasNestedLoops()); @@ -429,11 +428,10 @@ TEST_F(PassClassTest, TripleNestedLoop) { // Make sure LoopDescriptor gives us the inner most loop when we query for // loops. - for (const ir::BasicBlock& bb : *f) { - if (ir::Loop* loop = ld[&bb]) { - for (ir::Loop& sub_loop : - ir::make_range(++opt::TreeDFIterator(loop), - opt::TreeDFIterator())) { + for (const BasicBlock& bb : *f) { + if (Loop* loop = ld[&bb]) { + for (Loop& sub_loop : + make_range(++TreeDFIterator(loop), TreeDFIterator())) { EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id())); } } @@ -560,19 +558,19 @@ TEST_F(PassClassTest, LoopParentTest) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f); + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); EXPECT_EQ(ld.NumLoops(), 4u); { - ir::Loop& loop = *ld[22]; + Loop& loop = *ld[22]; EXPECT_TRUE(loop.HasNestedLoops()); EXPECT_FALSE(loop.IsNested()); EXPECT_EQ(loop.GetDepth(), 1u); @@ -580,7 +578,7 @@ TEST_F(PassClassTest, LoopParentTest) { } { - ir::Loop& loop = *ld[29]; + Loop& loop = *ld[29]; EXPECT_TRUE(loop.HasNestedLoops()); EXPECT_TRUE(loop.IsNested()); EXPECT_EQ(loop.GetDepth(), 2u); @@ -588,7 +586,7 @@ TEST_F(PassClassTest, LoopParentTest) { } { - ir::Loop& loop = *ld[36]; + Loop& loop = *ld[36]; EXPECT_FALSE(loop.HasNestedLoops()); EXPECT_TRUE(loop.IsNested()); EXPECT_EQ(loop.GetDepth(), 3u); @@ -596,7 +594,7 @@ TEST_F(PassClassTest, LoopParentTest) { } { - ir::Loop& loop = *ld[47]; + Loop& loop = *ld[47]; EXPECT_FALSE(loop.HasNestedLoops()); EXPECT_TRUE(loop.IsNested()); EXPECT_EQ(loop.GetDepth(), 2u); @@ -701,21 +699,21 @@ TEST_F(PassClassTest, CreatePreheaderTest) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f); + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); // No invalidation of the cfg should occur during this test. - ir::CFG* cfg = context->cfg(); + CFG* cfg = context->cfg(); EXPECT_EQ(ld.NumLoops(), 3u); { - ir::Loop& loop = *ld[16]; + Loop& loop = *ld[16]; EXPECT_TRUE(loop.HasNestedLoops()); EXPECT_FALSE(loop.IsNested()); EXPECT_EQ(loop.GetDepth(), 1u); @@ -723,7 +721,7 @@ TEST_F(PassClassTest, CreatePreheaderTest) { } { - ir::Loop& loop = *ld[33]; + Loop& loop = *ld[33]; EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr); EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr); // Make sure the loop descriptor was properly updated. @@ -736,12 +734,11 @@ TEST_F(PassClassTest, CreatePreheaderTest) { EXPECT_TRUE(pred_set.count(30)); EXPECT_TRUE(pred_set.count(31)); // Check the phi instructions. - loop.GetPreHeaderBlock()->ForEachPhiInst( - [&pred_set](ir::Instruction* phi) { - for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { - EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); - } - }); + loop.GetPreHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) { + for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { + EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); + } + }); } { const std::vector& preds = @@ -751,7 +748,7 @@ TEST_F(PassClassTest, CreatePreheaderTest) { EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id())); EXPECT_TRUE(pred_set.count(35)); // Check the phi instructions. - loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](ir::Instruction* phi) { + loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) { for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); } @@ -760,14 +757,14 @@ TEST_F(PassClassTest, CreatePreheaderTest) { } { - ir::Loop& loop = *ld[41]; + Loop& loop = *ld[41]; EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr); EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr); EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr); EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u); EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u); // Check the phi instructions. - loop.GetPreHeaderBlock()->ForEachPhiInst([](ir::Instruction* phi) { + loop.GetPreHeaderBlock()->ForEachPhiInst([](Instruction* phi) { EXPECT_EQ(phi->NumInOperands(), 2u); EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u); }); @@ -779,7 +776,7 @@ TEST_F(PassClassTest, CreatePreheaderTest) { EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id())); EXPECT_TRUE(pred_set.count(44)); // Check the phi instructions. - loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](ir::Instruction* phi) { + loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) { for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); } @@ -794,3 +791,5 @@ TEST_F(PassClassTest, CreatePreheaderTest) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling.cpp index a88e8dd35..e5db20b40 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling.cpp @@ -12,21 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/ir_builder.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_peeling.h" +#include "test/opt/pass_fixture.h" #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "../pass_fixture.h" -#include "opt/ir_builder.h" -#include "opt/loop_descriptor.h" -#include "opt/loop_peeling.h" - +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using PeelingTest = PassTest<::testing::Test>; bool Validate(const std::vector& bin) { @@ -41,7 +44,10 @@ bool Validate(const std::vector& bin) { return error == 0; } -void Match(const std::string& checks, ir::IRContext* context) { +void Match(const std::string& checks, IRContext* context) { + // Silence unused warnings with !defined(SPIRV_EFFCE) + (void)checks; + std::vector bin; context->module()->ToBinary(&bin, true); EXPECT_TRUE(Validate(bin)); @@ -56,7 +62,9 @@ void Match(const std::string& checks, ir::IRContext* context) { EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) << match_result.message() << "\nChecking result:\n" << assembly; -#endif // ! SPIRV_EFFCEE +#else // ! SPIRV_EFFCEE + (void)checks; +#endif } /* @@ -114,27 +122,27 @@ TEST_F(PeelingTest, CannotPeel) { // representing the loop count, if equals to 0, then the function build a 10 // constant as loop count. auto test_cannot_peel = [](const std::string& text, uint32_t loop_count_id) { - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - ir::Instruction* loop_count = nullptr; + Instruction* loop_count = nullptr; if (loop_count_id) { loop_count = context->get_def_use_mgr()->GetDef(loop_count_id); } else { - opt::InstructionBuilder builder(context.get(), &*f.begin()); + InstructionBuilder builder(context.get(), &*f.begin()); // Exit condition. loop_count = builder.Add32BitSignedIntegerConstant(10); } - opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count); + LoopPeeling peel(&*ld.begin(), loop_count); EXPECT_FALSE(peel.CanPeelLoop()); }; { @@ -480,22 +488,22 @@ TEST_F(PeelingTest, SimplePeeling) { { SCOPED_TRACE("Peel before"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - opt::InstructionBuilder builder(context.get(), &*f.begin()); + InstructionBuilder builder(context.get(), &*f.begin()); // Exit condition. - ir::Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10); + Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10); - opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst); + LoopPeeling peel(&*ld.begin(), ten_cst); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelBefore(2); @@ -534,22 +542,22 @@ CHECK-NEXT: OpLoopMerge { SCOPED_TRACE("Peel after"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - opt::InstructionBuilder builder(context.get(), &*f.begin()); + InstructionBuilder builder(context.get(), &*f.begin()); // Exit condition. - ir::Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10); + Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10); - opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst); + LoopPeeling peel(&*ld.begin(), ten_cst); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelAfter(2); @@ -580,6 +588,114 @@ CHECK: [[AFTER_LOOP]] = OpLabel CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Same as above, but reuse the induction variable. + // Peel before. + { + SCOPED_TRACE("Peel before with IV reuse"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. + Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst, + context->get_def_use_mgr()->GetDef(22)); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelBefore(2); + + const std::string check = R"( +CHECK: [[CST_TEN:%\w+]] = OpConstant {{%\w+}} 10 +CHECK: [[CST_TWO:%\w+]] = OpConstant {{%\w+}} 2 +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} [[CST_TWO]] [[CST_TEN]] +CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] [[CST_TWO]] [[CST_TEN]] +CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel +CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[i]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[AFTER_LOOP_PREHEADER]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[i]] [[AFTER_LOOP_PREHEADER]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Peel after. + { + SCOPED_TRACE("Peel after IV reuse"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. + Instruction* ten_cst = builder.Add32BitSignedIntegerConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst, + context->get_def_use_mgr()->GetDef(22)); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelAfter(2); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]] +CHECK: [[BEFORE_LOOP]] = OpLabel +CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[TMP:%\w+]] = OpIAdd {{%\w+}} [[I]] {{%\w+}} +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[TMP]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[BEFORE_LOOP_MERGE]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[IF_MERGE]] = OpLabel +CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I]] [[BEFORE_LOOP_MERGE]] +CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] +CHECK-NEXT: OpLoopMerge + )"; Match(check, context.get()); @@ -644,21 +760,21 @@ TEST_F(PeelingTest, PeelingUncountable) { { SCOPED_TRACE("Peel before"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(16); + Instruction* loop_count = context->get_def_use_mgr()->GetDef(16); EXPECT_EQ(loop_count->opcode(), SpvOpLoad); - opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count); + LoopPeeling peel(&*ld.begin(), loop_count); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelBefore(1); @@ -696,21 +812,21 @@ CHECK-NEXT: OpLoopMerge { SCOPED_TRACE("Peel after"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(16); + Instruction* loop_count = context->get_def_use_mgr()->GetDef(16); EXPECT_EQ(loop_count->opcode(), SpvOpLoad); - opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count); + LoopPeeling peel(&*ld.begin(), loop_count); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelAfter(1); @@ -797,21 +913,21 @@ TEST_F(PeelingTest, DoWhilePeeling) { { SCOPED_TRACE("Peel before"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - opt::InstructionBuilder builder(context.get(), &*f.begin()); + InstructionBuilder builder(context.get(), &*f.begin()); // Exit condition. - ir::Instruction* ten_cst = builder.Add32BitUnsignedIntegerConstant(10); + Instruction* ten_cst = builder.Add32BitUnsignedIntegerConstant(10); - opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst); + LoopPeeling peel(&*ld.begin(), ten_cst); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelBefore(2); @@ -846,22 +962,22 @@ CHECK-NEXT: OpLoopMerge { SCOPED_TRACE("Peel after"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - opt::InstructionBuilder builder(context.get(), &*f.begin()); + InstructionBuilder builder(context.get(), &*f.begin()); // Exit condition. - ir::Instruction* ten_cst = builder.Add32BitUnsignedIntegerConstant(10); + Instruction* ten_cst = builder.Add32BitUnsignedIntegerConstant(10); - opt::LoopPeeling peel(context.get(), &*ld.begin(), ten_cst); + LoopPeeling peel(&*ld.begin(), ten_cst); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelAfter(2); @@ -969,21 +1085,21 @@ TEST_F(PeelingTest, PeelingLoopWithStore) { { SCOPED_TRACE("Peel before"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(15); + Instruction* loop_count = context->get_def_use_mgr()->GetDef(15); EXPECT_EQ(loop_count->opcode(), SpvOpLoad); - opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count); + LoopPeeling peel(&*ld.begin(), loop_count); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelBefore(1); @@ -1021,21 +1137,21 @@ CHECK-NEXT: OpLoopMerge { SCOPED_TRACE("Peel after"); - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - ir::Function& f = *module->begin(); - ir::LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); EXPECT_EQ(ld.NumLoops(), 1u); - ir::Instruction* loop_count = context->get_def_use_mgr()->GetDef(15); + Instruction* loop_count = context->get_def_use_mgr()->GetDef(15); EXPECT_EQ(loop_count->opcode(), SpvOpLoad); - opt::LoopPeeling peel(context.get(), &*ld.begin(), loop_count); + LoopPeeling peel(&*ld.begin(), loop_count); EXPECT_TRUE(peel.CanPeelLoop()); peel.PeelAfter(1); @@ -1073,3 +1189,5 @@ CHECK-NEXT: OpLoopMerge } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling_pass.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling_pass.cpp new file mode 100644 index 000000000..284ad838d --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/peeling_pass.cpp @@ -0,0 +1,1099 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/ir_builder.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_peeling.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +class PeelingPassTest : public PassTest<::testing::Test> { + public: + // Generic routine to run the loop peeling pass and check + LoopPeelingPass::LoopPeelingStats AssembleAndRunPeelingTest( + const std::string& text_head, const std::string& text_tail, SpvOp opcode, + const std::string& res_id, const std::string& op1, + const std::string& op2) { + std::string opcode_str; + switch (opcode) { + case SpvOpSLessThan: + opcode_str = "OpSLessThan"; + break; + case SpvOpSGreaterThan: + opcode_str = "OpSGreaterThan"; + break; + case SpvOpSLessThanEqual: + opcode_str = "OpSLessThanEqual"; + break; + case SpvOpSGreaterThanEqual: + opcode_str = "OpSGreaterThanEqual"; + break; + case SpvOpIEqual: + opcode_str = "OpIEqual"; + break; + case SpvOpINotEqual: + opcode_str = "OpINotEqual"; + break; + default: + assert(false && "Unhandled"); + break; + } + std::string test_cond = + res_id + " = " + opcode_str + " %bool " + op1 + " " + op2 + "\n"; + + LoopPeelingPass::LoopPeelingStats stats; + SinglePassRunAndDisassemble( + text_head + test_cond + text_tail, true, true, &stats); + + return stats; + } + + // Generic routine to run the loop peeling pass and check + LoopPeelingPass::LoopPeelingStats RunPeelingTest( + const std::string& text_head, const std::string& text_tail, SpvOp opcode, + const std::string& res_id, const std::string& op1, const std::string& op2, + size_t nb_of_loops) { + LoopPeelingPass::LoopPeelingStats stats = AssembleAndRunPeelingTest( + text_head, text_tail, opcode, res_id, op1, op2); + + Function& f = *context()->module()->begin(); + LoopDescriptor& ld = *context()->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), nb_of_loops); + + return stats; + } + + using PeelTraceType = + std::vector>; + + void BuildAndCheckTrace(const std::string& text_head, + const std::string& text_tail, SpvOp opcode, + const std::string& res_id, const std::string& op1, + const std::string& op2, + const PeelTraceType& expected_peel_trace, + size_t expected_nb_of_loops) { + auto stats = RunPeelingTest(text_head, text_tail, opcode, res_id, op1, op2, + expected_nb_of_loops); + + EXPECT_EQ(stats.peeled_loops_.size(), expected_peel_trace.size()); + if (stats.peeled_loops_.size() != expected_peel_trace.size()) { + return; + } + + PeelTraceType::const_iterator expected_trace_it = + expected_peel_trace.begin(); + decltype(stats.peeled_loops_)::const_iterator stats_it = + stats.peeled_loops_.begin(); + + while (expected_trace_it != expected_peel_trace.end()) { + EXPECT_EQ(expected_trace_it->first, std::get<1>(*stats_it)); + EXPECT_EQ(expected_trace_it->second, std::get<2>(*stats_it)); + ++expected_trace_it; + ++stats_it; + } + } +}; + +/* +Test are derivation of the following generated test from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for(int i = 1; i < 10; i += 2) { + if (i < 3) { + a += 2; + } + } +} + +The condition is interchanged to test < > <= >= == and peel before/after +opportunities. +*/ +TEST_F(PeelingPassTest, PeelingPassBasic) { + const std::string text_head = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %bool = OpTypeBool + %int_20 = OpConstant %int 20 + %int_19 = OpConstant %int 19 + %int_18 = OpConstant %int 18 + %int_17 = OpConstant %int 17 + %int_16 = OpConstant %int 16 + %int_15 = OpConstant %int 15 + %int_14 = OpConstant %int 14 + %int_13 = OpConstant %int 13 + %int_12 = OpConstant %int 12 + %int_11 = OpConstant %int 11 + %int_10 = OpConstant %int 10 + %int_9 = OpConstant %int 9 + %int_8 = OpConstant %int 8 + %int_7 = OpConstant %int 7 + %int_6 = OpConstant %int 6 + %int_5 = OpConstant %int 5 + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpBranch %11 + %11 = OpLabel + %31 = OpPhi %int %int_0 %5 %33 %14 + %32 = OpPhi %int %int_1 %5 %30 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %bool %32 %int_20 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + )"; + const std::string text_tail = R"( + OpSelectionMerge %24 None + OpBranchConditional %22 %23 %24 + %23 = OpLabel + %27 = OpIAdd %int %31 %int_2 + OpStore %a %27 + OpBranch %24 + %24 = OpLabel + %33 = OpPhi %int %31 %12 %27 %23 + OpBranch %14 + %14 = OpLabel + %30 = OpIAdd %int %32 %int_2 + OpStore %i %30 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto run_test = [&text_head, &text_tail, this](SpvOp opcode, + const std::string& op1, + const std::string& op2) { + auto stats = + RunPeelingTest(text_head, text_tail, opcode, "%22", op1, op2, 2); + + EXPECT_EQ(stats.peeled_loops_.size(), 1u); + if (stats.peeled_loops_.size() != 1u) + return std::pair{ + LoopPeelingPass::PeelDirection::kNone, 0}; + + return std::pair{ + std::get<1>(*stats.peeled_loops_.begin()), + std::get<2>(*stats.peeled_loops_.begin())}; + }; + + // Test LT + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv < 4"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv < 5"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_5"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 5 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_5", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv < 16"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv < 17"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_17"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 17 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_17", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test GT + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv > 5"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_5"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 5 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_5", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv > 4"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv > 16"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv > 17"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_17"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 17 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_17", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test LE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv <= 4"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv <= 3"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_3"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 3 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_3", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv <= 16"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv <= 15"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_15"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 15 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_15", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test GE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv >= 5"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_5"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 35 >= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_5", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv >= 4"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 <= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv >= 17"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_17"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 17 <= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_17", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv >= 16"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 <= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test EQ + // Peel before by a factor of 1. + { + SCOPED_TRACE("Peel before iv == 1"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%32", "%int_1"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel before 1 == iv"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%int_1", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv == 19"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%32", "%int_19"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel after 19 == iv"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%int_19", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + + // Test NE + // Peel before by a factor of 1. + { + SCOPED_TRACE("Peel before iv != 1"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%32", "%int_1"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel before 1 != iv"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%int_1", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv != 19"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%32", "%int_19"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel after 19 != iv"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%int_19", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + + // No peel. + { + SCOPED_TRACE("No Peel: 20 => iv"); + + auto stats = RunPeelingTest(text_head, text_tail, SpvOpSLessThanEqual, + "%22", "%int_20", "%32", 1); + + EXPECT_EQ(stats.peeled_loops_.size(), 0u); + } +} + +/* +Test are derivation of the following generated test from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for(int i = 0; i < 10; ++i) { + if (i < 3) { + a += 2; + } + if (i < 1) { + a += 2; + } + } +} + +The condition is interchanged to test < > <= >= == and peel before/after +opportunities. +*/ +TEST_F(PeelingPassTest, MultiplePeelingPass) { + const std::string text_head = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %bool = OpTypeBool + %int_10 = OpConstant %int 10 + %int_9 = OpConstant %int 9 + %int_8 = OpConstant %int 8 + %int_7 = OpConstant %int 7 + %int_6 = OpConstant %int 6 + %int_5 = OpConstant %int 5 + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpBranch %11 + %11 = OpLabel + %37 = OpPhi %int %int_0 %5 %40 %14 + %38 = OpPhi %int %int_0 %5 %36 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %bool %38 %int_10 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + )"; + const std::string text_tail = R"( + OpSelectionMerge %24 None + OpBranchConditional %22 %23 %24 + %23 = OpLabel + %27 = OpIAdd %int %37 %int_2 + OpStore %a %27 + OpBranch %24 + %24 = OpLabel + %39 = OpPhi %int %37 %12 %27 %23 + %30 = OpSLessThan %bool %38 %int_1 + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %32 + %31 = OpLabel + %34 = OpIAdd %int %39 %int_2 + OpStore %a %34 + OpBranch %32 + %32 = OpLabel + %40 = OpPhi %int %39 %24 %34 %31 + OpBranch %14 + %14 = OpLabel + %36 = OpIAdd %int %38 %int_1 + OpStore %i %36 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto run_test = [&text_head, &text_tail, this]( + SpvOp opcode, const std::string& op1, + const std::string& op2, + const PeelTraceType& expected_peel_trace) { + BuildAndCheckTrace(text_head, text_tail, opcode, "%22", op1, op2, + expected_peel_trace, expected_peel_trace.size() + 1); + }; + + // Test LT + // Peel before by a factor of 3. + { + SCOPED_TRACE("Peel before iv < 3"); + + run_test(SpvOpSLessThan, "%38", "%int_3", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}); + } + { + SCOPED_TRACE("Peel before 3 > iv"); + + run_test(SpvOpSGreaterThan, "%int_3", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv < 8"); + + run_test(SpvOpSLessThan, "%38", "%int_8", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + { + SCOPED_TRACE("Peel after 8 > iv"); + + run_test(SpvOpSGreaterThan, "%int_8", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + + // Test GT + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv > 2"); + + run_test(SpvOpSGreaterThan, "%38", "%int_2", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + { + SCOPED_TRACE("Peel before 2 < iv"); + + run_test(SpvOpSLessThan, "%int_2", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + + // Peel after by a factor of 3. + { + SCOPED_TRACE("Peel after iv > 7"); + + run_test(SpvOpSGreaterThan, "%38", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}); + } + { + SCOPED_TRACE("Peel after 7 < iv"); + + run_test(SpvOpSLessThan, "%int_7", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}); + } + + // Test LE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv <= 1"); + + run_test(SpvOpSLessThanEqual, "%38", "%int_1", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + { + SCOPED_TRACE("Peel before 1 => iv"); + + run_test(SpvOpSGreaterThanEqual, "%int_1", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv <= 7"); + + run_test(SpvOpSLessThanEqual, "%38", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + { + SCOPED_TRACE("Peel after 7 => iv"); + + run_test(SpvOpSGreaterThanEqual, "%int_7", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + + // Test GE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv >= 2"); + + run_test(SpvOpSGreaterThanEqual, "%38", "%int_2", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + { + SCOPED_TRACE("Peel before 2 <= iv"); + + run_test(SpvOpSLessThanEqual, "%int_2", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv >= 8"); + + run_test(SpvOpSGreaterThanEqual, "%38", "%int_8", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + { + SCOPED_TRACE("Peel after 8 <= iv"); + + run_test(SpvOpSLessThanEqual, "%int_8", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + // Test EQ + // Peel before by a factor of 1. + { + SCOPED_TRACE("Peel before iv == 0"); + + run_test(SpvOpIEqual, "%38", "%int_0", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel before 0 == iv"); + + run_test(SpvOpIEqual, "%int_0", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv == 9"); + + run_test(SpvOpIEqual, "%38", "%int_9", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel after 9 == iv"); + + run_test(SpvOpIEqual, "%int_9", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + + // Test NE + // Peel before by a factor of 1. + { + SCOPED_TRACE("Peel before iv != 0"); + + run_test(SpvOpINotEqual, "%38", "%int_0", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel before 0 != iv"); + + run_test(SpvOpINotEqual, "%int_0", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv != 9"); + + run_test(SpvOpINotEqual, "%38", "%int_9", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel after 9 != iv"); + + run_test(SpvOpINotEqual, "%int_9", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } +} + +/* +Test are derivation of the following generated test from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (i < 3) { + a += 2; + } + } + } +} +*/ +TEST_F(PeelingPassTest, PeelingNestedPass) { + const std::string text_head = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + OpName %j "j" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_7 = OpConstant %int 7 + %int_3 = OpConstant %int 3 + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %43 = OpUndef %int + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + %j = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpBranch %11 + %11 = OpLabel + %41 = OpPhi %int %int_0 %5 %45 %14 + %42 = OpPhi %int %int_0 %5 %40 %14 + %44 = OpPhi %int %43 %5 %46 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %bool %42 %int_10 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + OpStore %j %int_0 + OpBranch %21 + %21 = OpLabel + %45 = OpPhi %int %41 %12 %47 %24 + %46 = OpPhi %int %int_0 %12 %38 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %27 = OpSLessThan %bool %46 %int_10 + OpBranchConditional %27 %22 %23 + %22 = OpLabel + )"; + + const std::string text_tail = R"( + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %32 + %31 = OpLabel + %35 = OpIAdd %int %45 %int_2 + OpStore %a %35 + OpBranch %32 + %32 = OpLabel + %47 = OpPhi %int %45 %22 %35 %31 + OpBranch %24 + %24 = OpLabel + %38 = OpIAdd %int %46 %int_1 + OpStore %j %38 + OpBranch %21 + %23 = OpLabel + OpBranch %14 + %14 = OpLabel + %40 = OpIAdd %int %42 %int_1 + OpStore %i %40 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto run_test = + [&text_head, &text_tail, this]( + SpvOp opcode, const std::string& op1, const std::string& op2, + const PeelTraceType& expected_peel_trace, size_t nb_of_loops) { + BuildAndCheckTrace(text_head, text_tail, opcode, "%30", op1, op2, + expected_peel_trace, nb_of_loops); + }; + + // Peeling outer before by a factor of 3. + { + SCOPED_TRACE("Peel before iv_i < 3"); + + // Expect peel before by a factor of 3 and 4 loops at the end. + run_test(SpvOpSLessThan, "%42", "%int_3", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}, 4); + } + // Peeling outer loop after by a factor of 3. + { + SCOPED_TRACE("Peel after iv_i < 7"); + + // Expect peel after by a factor of 3 and 4 loops at the end. + run_test(SpvOpSLessThan, "%42", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}, 4); + } + + // Peeling inner loop before by a factor of 3. + { + SCOPED_TRACE("Peel before iv_j < 3"); + + // Expect peel before by a factor of 3 and 3 loops at the end. + run_test(SpvOpSLessThan, "%46", "%int_3", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}, 3); + } + // Peeling inner loop after by a factor of 3. + { + SCOPED_TRACE("Peel after iv_j < 7"); + + // Expect peel after by a factor of 3 and 3 loops at the end. + run_test(SpvOpSLessThan, "%46", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}, 3); + } + + // Not unworkable condition. + { + SCOPED_TRACE("No peel"); + + // Expect no peeling and 2 loops at the end. + run_test(SpvOpSLessThan, "%46", "%42", {}, 2); + } + + // Could do a peeling of 3, but the goes over the threshold. + { + SCOPED_TRACE("Over threshold"); + + size_t current_threshold = LoopPeelingPass::GetLoopPeelingThreshold(); + LoopPeelingPass::SetLoopPeelingThreshold(1u); + // Expect no peeling and 2 loops at the end. + run_test(SpvOpSLessThan, "%46", "%int_7", {}, 2); + LoopPeelingPass::SetLoopPeelingThreshold(current_threshold); + } +} +/* +Test are derivation of the following generated test from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for (int i = 0, j = 0; i < 10; j++, i++) { + if (i < j) { + a += 2; + } + } +} +*/ +TEST_F(PeelingPassTest, PeelingNoChanges) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + OpName %j "j" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + %j = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpStore %j %int_0 + OpBranch %12 + %12 = OpLabel + %34 = OpPhi %int %int_0 %5 %37 %15 + %35 = OpPhi %int %int_0 %5 %33 %15 + %36 = OpPhi %int %int_0 %5 %31 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %bool %35 %int_10 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %23 = OpSLessThan %bool %35 %36 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %25 + %24 = OpLabel + %28 = OpIAdd %int %34 %int_2 + OpStore %a %28 + OpBranch %25 + %25 = OpLabel + %37 = OpPhi %int %34 %13 %28 %24 + OpBranch %15 + %15 = OpLabel + %31 = OpIAdd %int %36 %int_1 + OpStore %j %31 + %33 = OpIAdd %int %35 %int_1 + OpStore %i %33 + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + { + auto result = + SinglePassRunAndDisassemble(text, true, false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_assumptions.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_assumptions.cpp index 3a991c7bb..62f77d782 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_assumptions.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_assumptions.cpp @@ -12,37 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/loop_unroller.h" -#include "opt/loop_utils.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; template -class PartialUnrollerTestPass : public opt::Pass { +class PartialUnrollerTestPass : public Pass { public: PartialUnrollerTestPass() : Pass() {} const char* name() const override { return "Loop unroller"; } - Status Process(ir::IRContext* context) override { + Status Process() override { bool changed = false; - for (ir::Function& f : *context->module()) { - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f); + for (Function& f : *context()->module()) { + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); for (auto& loop : loop_descriptor) { - opt::LoopUtils loop_utils{context, &loop}; + LoopUtils loop_utils{context(), &loop}; if (loop_utils.PartiallyUnroll(factor)) { changed = true; } @@ -54,8 +56,6 @@ class PartialUnrollerTestPass : public opt::Pass { } }; -using PassClassTest = PassTest<::testing::Test>; - /* Generated from the following GLSL #version 410 core @@ -68,7 +68,7 @@ void main() { */ TEST_F(PassClassTest, CheckUpperBound) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -120,18 +120,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -150,7 +150,7 @@ void main() { */ TEST_F(PassClassTest, UnrollNestedLoopsInvalid) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -220,16 +220,16 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); } /* @@ -247,7 +247,7 @@ void main(){ */ TEST_F(PassClassTest, BreakInBody) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -301,16 +301,16 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); } /* @@ -328,7 +328,7 @@ void main(){ */ TEST_F(PassClassTest, ContinueInBody) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -382,16 +382,16 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); } /* @@ -409,7 +409,7 @@ void main(){ */ TEST_F(PassClassTest, ReturnInBody) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -464,7 +464,7 @@ OpFunctionEnd )"; // clang-format on SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); } /* @@ -479,7 +479,7 @@ void main() { */ TEST_F(PassClassTest, MultipleConditionsSingleVariable) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -519,18 +519,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -549,7 +549,7 @@ void main() { */ TEST_F(PassClassTest, MultipleConditionsMultipleVariables) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -591,18 +591,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -620,7 +620,7 @@ void main() { */ TEST_F(PassClassTest, FloatingPointLoop) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -662,18 +662,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -692,7 +692,7 @@ void main() { */ TEST_F(PassClassTest, InductionPhiOutsideLoop) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -739,18 +739,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -791,7 +791,7 @@ void main() { */ TEST_F(PassClassTest, UnsupportedLoopTypes) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -957,18 +957,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -986,7 +986,7 @@ void main(void) { */ TEST_F(PassClassTest, NegativeNumberOfIterations) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1031,18 +1031,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -1063,7 +1063,7 @@ void main(void) { */ TEST_F(PassClassTest, MultipleStepOperations) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1111,18 +1111,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -1143,7 +1143,7 @@ void main(void) { TEST_F(PassClassTest, ConditionFalseFromStartGreaterThan) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1190,18 +1190,18 @@ OpReturn OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -1221,7 +1221,7 @@ void main(void) { */ TEST_F(PassClassTest, ConditionFalseFromStartGreaterThanOrEqual) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1269,18 +1269,18 @@ OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -1300,7 +1300,7 @@ void main(void) { */ TEST_F(PassClassTest, ConditionFalseFromStartLessThan) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1348,18 +1348,18 @@ OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } @@ -1379,7 +1379,7 @@ void main(void) { */ TEST_F(PassClassTest, ConditionFalseFromStartLessThanEqual) { // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1427,20 +1427,22 @@ OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // Make sure the pass doesn't run - SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck(text, text, false); SinglePassRunAndCheck>(text, text, false); SinglePassRunAndCheck>(text, text, false); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_simple.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_simple.cpp index 3d8fb67f3..3b01fdc31 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_simple.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/unroll_simple.cpp @@ -12,24 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "../assembly_builder.h" -#include "../function_utils.h" -#include "../pass_fixture.h" -#include "../pass_utils.h" -#include "opt/loop_unroller.h" -#include "opt/loop_utils.h" -#include "opt/pass.h" +#include "gmock/gmock.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; /* @@ -44,8 +44,7 @@ void main() { } */ TEST_F(PassClassTest, SimpleFullyUnrollTest) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -98,8 +97,7 @@ TEST_F(PassClassTest, SimpleFullyUnrollTest) { OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" %3 @@ -183,31 +181,31 @@ OpBranch %27 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } template -class PartialUnrollerTestPass : public opt::Pass { +class PartialUnrollerTestPass : public Pass { public: PartialUnrollerTestPass() : Pass() {} const char* name() const override { return "Loop unroller"; } - Status Process(ir::IRContext* context) override { - for (ir::Function& f : *context->module()) { - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f); + Status Process() override { + for (Function& f : *context()->module()) { + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); for (auto& loop : loop_descriptor) { - opt::LoopUtils loop_utils{context, &loop}; + LoopUtils loop_utils{context(), &loop}; loop_utils.PartiallyUnroll(factor); } } @@ -228,8 +226,7 @@ void main() { } */ TEST_F(PassClassTest, SimplePartialUnroll) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -344,15 +341,15 @@ OpBranch %23 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); SinglePassRunAndCheck>(text, output, false); } @@ -369,8 +366,7 @@ void main() { } */ TEST_F(PassClassTest, SimpleUnevenPartialUnroll) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -423,8 +419,7 @@ TEST_F(PassClassTest, SimpleUnevenPartialUnroll) { OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" %3 @@ -517,15 +512,15 @@ OpReturn OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); // By unrolling by a factor that doesn't divide evenly into the number of loop // iterations we perfom an additional transform when partially unrolling to @@ -544,8 +539,7 @@ void main() { } */ TEST_F(PassClassTest, SimpleLoopIterationsCheck) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -596,31 +590,30 @@ OpBranch %21 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - ir::Function* f = spvtest::GetFunction(module, 2); + Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); EXPECT_EQ(loop_descriptor.NumLoops(), 1u); - ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + Loop& loop = loop_descriptor.GetLoopByIndex(0); EXPECT_TRUE(loop.HasUnrollLoopControl()); - ir::BasicBlock* condition = loop.FindConditionBlock(); + BasicBlock* condition = loop.FindConditionBlock(); EXPECT_EQ(condition->id(), 24u); - ir::Instruction* induction = loop.FindConditionVariable(condition); + Instruction* induction = loop.FindConditionVariable(condition); EXPECT_EQ(induction->result_id(), 34u); - opt::LoopUtils loop_utils{context.get(), &loop}; + LoopUtils loop_utils{context.get(), &loop}; EXPECT_TRUE(loop_utils.CanPerformUnroll()); size_t iterations = 0; @@ -639,8 +632,7 @@ void main() { } */ TEST_F(PassClassTest, SimpleLoopIterationsCheckSignedInit) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -692,32 +684,31 @@ OpBranch %22 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - ir::Function* f = spvtest::GetFunction(module, 2); + Function* f = spvtest::GetFunction(module, 2); - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); EXPECT_EQ(loop_descriptor.NumLoops(), 1u); - ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + Loop& loop = loop_descriptor.GetLoopByIndex(0); EXPECT_FALSE(loop.HasUnrollLoopControl()); - ir::BasicBlock* condition = loop.FindConditionBlock(); + BasicBlock* condition = loop.FindConditionBlock(); EXPECT_EQ(condition->id(), 25u); - ir::Instruction* induction = loop.FindConditionVariable(condition); + Instruction* induction = loop.FindConditionVariable(condition); EXPECT_EQ(induction->result_id(), 35u); - opt::LoopUtils loop_utils{context.get(), &loop}; + LoopUtils loop_utils{context.get(), &loop}; EXPECT_TRUE(loop_utils.CanPerformUnroll()); size_t iterations = 0; @@ -739,8 +730,7 @@ void main() { } */ TEST_F(PassClassTest, UnrollNestedLoops) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -809,8 +799,7 @@ TEST_F(PassClassTest, UnrollNestedLoops) { OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -964,16 +953,16 @@ OpBranch %27 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } /* @@ -987,9 +976,8 @@ void main() { } */ TEST_F(PassClassTest, NegativeConditionAndInit) { - // clang-format off - // With opt::LocalMultiStoreElimPass -const std::string text = R"( + // With LocalMultiStoreElimPass + const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1035,7 +1023,7 @@ const std::string text = R"( OpFunctionEnd )"; -const std::string expected = R"(OpCapability Shader + const std::string expected = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -1090,42 +1078,41 @@ OpBranch %22 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - // SinglePassRunAndCheck(text, expected, false); + // SinglePassRunAndCheck(text, expected, false); - ir::Function* f = spvtest::GetFunction(module, 4); + Function* f = spvtest::GetFunction(module, 4); - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); EXPECT_EQ(loop_descriptor.NumLoops(), 1u); - ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + Loop& loop = loop_descriptor.GetLoopByIndex(0); EXPECT_TRUE(loop.HasUnrollLoopControl()); - ir::BasicBlock* condition = loop.FindConditionBlock(); + BasicBlock* condition = loop.FindConditionBlock(); EXPECT_EQ(condition->id(), 14u); - ir::Instruction* induction = loop.FindConditionVariable(condition); + Instruction* induction = loop.FindConditionVariable(condition); EXPECT_EQ(induction->result_id(), 32u); - opt::LoopUtils loop_utils{context.get(), &loop}; + LoopUtils loop_utils{context.get(), &loop}; EXPECT_TRUE(loop_utils.CanPerformUnroll()); size_t iterations = 0; EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), &iterations)); EXPECT_EQ(iterations, 2u); - SinglePassRunAndCheck(text, expected, false); + SinglePassRunAndCheck(text, expected, false); } /* @@ -1139,9 +1126,8 @@ void main() { } */ TEST_F(PassClassTest, NegativeConditionAndInitResidualUnroll) { - // clang-format off - // With opt::LocalMultiStoreElimPass -const std::string text = R"( + // With LocalMultiStoreElimPass + const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1187,7 +1173,7 @@ const std::string text = R"( OpFunctionEnd )"; -const std::string expected = R"(OpCapability Shader + const std::string expected = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -1265,32 +1251,32 @@ OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - ir::Function* f = spvtest::GetFunction(module, 4); + Function* f = spvtest::GetFunction(module, 4); - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); EXPECT_EQ(loop_descriptor.NumLoops(), 1u); - ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + Loop& loop = loop_descriptor.GetLoopByIndex(0); EXPECT_TRUE(loop.HasUnrollLoopControl()); - ir::BasicBlock* condition = loop.FindConditionBlock(); + BasicBlock* condition = loop.FindConditionBlock(); EXPECT_EQ(condition->id(), 14u); - ir::Instruction* induction = loop.FindConditionVariable(condition); + Instruction* induction = loop.FindConditionVariable(condition); EXPECT_EQ(induction->result_id(), 32u); - opt::LoopUtils loop_utils{context.get(), &loop}; + LoopUtils loop_utils{context.get(), &loop}; EXPECT_TRUE(loop_utils.CanPerformUnroll()); size_t iterations = 0; @@ -1313,8 +1299,7 @@ void main() { } */ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -1383,26 +1368,24 @@ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { OpFunctionEnd )"; - // clang-format on - { // Test fully unroll - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - ir::Function* f = spvtest::GetFunction(module, 4); - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); EXPECT_EQ(loop_descriptor.NumLoops(), 2u); - ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); + Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); EXPECT_TRUE(outer_loop.HasUnrollLoopControl()); - ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); + Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); EXPECT_TRUE(inner_loop.HasUnrollLoopControl()); @@ -1413,7 +1396,7 @@ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u); { - opt::LoopUtils loop_utils{context.get(), &inner_loop}; + LoopUtils loop_utils{context.get(), &inner_loop}; loop_utils.FullyUnroll(); loop_utils.Finalize(); } @@ -1422,7 +1405,7 @@ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { EXPECT_EQ(outer_loop.GetBlocks().size(), 25u); EXPECT_EQ(outer_loop.NumImmediateChildren(), 0u); { - opt::LoopUtils loop_utils{context.get(), &outer_loop}; + LoopUtils loop_utils{context.get(), &outer_loop}; loop_utils.FullyUnroll(); loop_utils.Finalize(); } @@ -1430,23 +1413,23 @@ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { } { // Test partially unroll - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - ir::Function* f = spvtest::GetFunction(module, 4); - ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); EXPECT_EQ(loop_descriptor.NumLoops(), 2u); - ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); + Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); EXPECT_TRUE(outer_loop.HasUnrollLoopControl()); - ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); + Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); EXPECT_TRUE(inner_loop.HasUnrollLoopControl()); @@ -1457,7 +1440,7 @@ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u); EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u); - opt::LoopUtils loop_utils{context.get(), &inner_loop}; + LoopUtils loop_utils{context.get(), &inner_loop}; loop_utils.PartiallyUnroll(2); loop_utils.Finalize(); @@ -1479,8 +1462,7 @@ void main() { } */ TEST_F(PassClassTest, FullyUnrollNegativeStepLoopTest) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -1528,8 +1510,7 @@ TEST_F(PassClassTest, FullyUnrollNegativeStepLoopTest) { OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -1598,17 +1579,17 @@ OpBranch %23 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } /* @@ -1622,8 +1603,7 @@ void main() { } */ TEST_F(PassClassTest, FullyUnrollNegativeNonOneStepLoop) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -1671,8 +1651,7 @@ TEST_F(PassClassTest, FullyUnrollNegativeNonOneStepLoop) { OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -1741,17 +1720,17 @@ OpBranch %23 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } /* @@ -1765,8 +1744,7 @@ void main() { } */ TEST_F(PassClassTest, FullyUnrollNonDivisibleStepLoop) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1813,8 +1791,7 @@ OpReturn OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -1883,17 +1860,17 @@ OpBranch %23 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } /* @@ -1907,8 +1884,7 @@ void main() { } */ TEST_F(PassClassTest, FullyUnrollNegativeNonDivisibleStepLoop) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -1955,8 +1931,7 @@ OpReturn OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -2038,21 +2013,20 @@ OpBranch %23 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } -// clang-format off -// With opt::LocalMultiStoreElimPass +// With LocalMultiStoreElimPass static const std::string multiple_phi_shader = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -2102,12 +2076,9 @@ static const std::string multiple_phi_shader = R"( OpReturnValue %37 OpFunctionEnd )"; -// clang-format on TEST_F(PassClassTest, PartiallyUnrollResidualMultipleInductionVariables) { - // clang-format off -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -2218,24 +2189,22 @@ OpReturnValue %45 OpReturnValue %30 OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << multiple_phi_shader << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); SinglePassRunAndCheck>(multiple_phi_shader, output, false); } TEST_F(PassClassTest, PartiallyUnrollMultipleInductionVariables) { - // clang-format off -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -2296,24 +2265,22 @@ OpBranch %17 OpReturnValue %30 OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << multiple_phi_shader << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); SinglePassRunAndCheck>(multiple_phi_shader, output, false); } TEST_F(PassClassTest, FullyUnrollMultipleInductionVariables) { - // clang-format off -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -2422,17 +2389,17 @@ OpBranch %25 OpReturnValue %30 OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << multiple_phi_shader << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(multiple_phi_shader, output, false); + SinglePassRunAndCheck(multiple_phi_shader, output, false); } /* @@ -2449,8 +2416,7 @@ void main() } */ TEST_F(PassClassTest, FullyUnrollEqualToOperations) { - // clang-format off - // With opt::LocalMultiStoreElimPass + // With LocalMultiStoreElimPass const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" @@ -2505,8 +2471,7 @@ TEST_F(PassClassTest, FullyUnrollEqualToOperations) { OpFunctionEnd )"; -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" @@ -2585,22 +2550,21 @@ OpBranch %28 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << text << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(text, output, false); + SinglePassRunAndCheck(text, output, false); } -// clang-format off - // With opt::LocalMultiStoreElimPass - const std::string condition_in_header = R"( +// With LocalMultiStoreElimPass +const std::string condition_in_header = R"( OpCapability Shader OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" %o @@ -2637,14 +2601,9 @@ OpFunctionEnd OpReturn OpFunctionEnd )"; -//clang-format on - TEST_F(PassClassTest, FullyUnrollConditionIsInHeaderBlock) { - -// clang-format off -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %1 "main" %2 OpExecutionMode %1 OriginUpperLeft @@ -2700,23 +2659,21 @@ OpBranch %18 OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, condition_in_header, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << condition_in_header << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(condition_in_header, output, false); + SinglePassRunAndCheck(condition_in_header, output, false); } TEST_F(PassClassTest, PartiallyUnrollResidualConditionIsInHeaderBlock) { - // clang-format off -const std::string output = -R"(OpCapability Shader + const std::string output = R"(OpCapability Shader OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %1 "main" %2 OpExecutionMode %1 OriginUpperLeft @@ -2782,18 +2739,219 @@ OpReturn OpReturn OpFunctionEnd )"; - // clang-format on - std::unique_ptr context = + + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, condition_in_header, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" << condition_in_header << std::endl; - opt::LoopUnroller loop_unroller; + LoopUnroller loop_unroller; SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); SinglePassRunAndCheck>(condition_in_header, output, false); } +/* +Generated from following GLSL with latch block artificially inserted to be +seperate from continue. +#version 430 +void main(void) { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, PartiallyUnrollLatchNotContinue) { + const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "x" + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpTypeFloat 32 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 10 + %15 = OpTypeArray %12 %14 + %16 = OpTypePointer Function %15 + %17 = OpTypePointer Function %12 + %18 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %19 = OpLabel + %3 = OpVariable %8 Function + %4 = OpVariable %16 Function + OpStore %3 %9 + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %7 %9 %19 %22 %30 + OpLoopMerge %24 %23 Unroll + OpBranch %25 + %25 = OpLabel + %26 = OpSLessThan %11 %21 %10 + OpBranchConditional %26 %27 %24 + %27 = OpLabel + %28 = OpConvertSToF %12 %21 + %29 = OpAccessChain %17 %4 %21 + OpStore %29 %28 + OpBranch %23 + %23 = OpLabel + %22 = OpIAdd %7 %21 %18 + OpStore %3 %22 + OpBranch %30 + %30 = OpLabel + OpBranch %20 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "x" +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 10 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %7 1 +%63 = OpConstant %13 1 +%2 = OpFunction %5 None %6 +%19 = OpLabel +%3 = OpVariable %8 Function +%4 = OpVariable %16 Function +OpStore %3 %9 +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %7 %9 %19 %22 %23 +OpLoopMerge %31 %25 Unroll +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %11 %21 %63 +OpBranchConditional %27 %28 %31 +%28 = OpLabel +%29 = OpConvertSToF %12 %21 +%30 = OpAccessChain %17 %4 %21 +OpStore %30 %29 +OpBranch %25 +%25 = OpLabel +%22 = OpIAdd %7 %21 %18 +OpStore %3 %22 +OpBranch %23 +%23 = OpLabel +OpBranch %20 +%31 = OpLabel +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %7 %21 %31 %61 %62 +OpLoopMerge %42 %60 DontUnroll +OpBranch %34 +%34 = OpLabel +%35 = OpSLessThan %11 %33 %10 +OpBranchConditional %35 %36 %42 +%36 = OpLabel +%37 = OpConvertSToF %12 %33 +%38 = OpAccessChain %17 %4 %33 +OpStore %38 %37 +OpBranch %39 +%39 = OpLabel +%40 = OpIAdd %7 %33 %18 +OpStore %3 %40 +OpBranch %41 +%41 = OpLabel +OpBranch %43 +%43 = OpLabel +OpBranch %45 +%45 = OpLabel +%46 = OpSLessThan %11 %40 %10 +OpBranch %47 +%47 = OpLabel +%48 = OpConvertSToF %12 %40 +%49 = OpAccessChain %17 %4 %40 +OpStore %49 %48 +OpBranch %50 +%50 = OpLabel +%51 = OpIAdd %7 %40 %18 +OpStore %3 %51 +OpBranch %52 +%52 = OpLabel +OpBranch %53 +%53 = OpLabel +OpBranch %55 +%55 = OpLabel +%56 = OpSLessThan %11 %51 %10 +OpBranch %57 +%57 = OpLabel +%58 = OpConvertSToF %12 %51 +%59 = OpAccessChain %17 %4 %51 +OpStore %59 %58 +OpBranch %60 +%60 = OpLabel +%61 = OpIAdd %7 %51 %18 +OpStore %3 %61 +OpBranch %62 +%62 = OpLabel +OpBranch %32 +%42 = OpLabel +OpReturn +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(text, expected, true); + + // Make sure the latch block information is preserved and propagated correctly + // by the pass. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + PartialUnrollerTestPass<3> unroller; + unroller.SetContextForTesting(context.get()); + unroller.Process(); + + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + EXPECT_EQ(ld.NumLoops(), 2u); + + Loop& loop_1 = ld.GetLoopByIndex(0u); + EXPECT_NE(loop_1.GetLatchBlock(), loop_1.GetContinueBlock()); + + Loop& loop_2 = ld.GetLoopByIndex(1u); + EXPECT_NE(loop_2.GetLatchBlock(), loop_2.GetContinueBlock()); +} + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/loop_optimizations/unswitch.cpp b/3rdparty/spirv-tools/test/opt/loop_optimizations/unswitch.cpp index d5c5209aa..96a7fc010 100644 --- a/3rdparty/spirv-tools/test/opt/loop_optimizations/unswitch.cpp +++ b/3rdparty/spirv-tools/test/opt/loop_optimizations/unswitch.cpp @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include + +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "../pass_fixture.h" - +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using UnswitchTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -151,7 +152,7 @@ TEST_F(UnswitchTest, SimpleUnswitch) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } /* @@ -257,7 +258,7 @@ TEST_F(UnswitchTest, UnswitchExit) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } /* @@ -373,7 +374,7 @@ TEST_F(UnswitchTest, UnswitchContinue) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } /* @@ -479,7 +480,7 @@ TEST_F(UnswitchTest, UnswitchKillLoop) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } /* @@ -605,7 +606,7 @@ TEST_F(UnswitchTest, UnswitchSwitch) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } /* @@ -806,7 +807,7 @@ TEST_F(UnswitchTest, UnSwitchNested) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE @@ -906,9 +907,11 @@ TEST_F(UnswitchTest, UnswitchNotUniform) { )"; auto result = - SinglePassRunAndDisassemble(text, true, false); + SinglePassRunAndDisassemble(text, true, false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/module_test.cpp b/3rdparty/spirv-tools/test/opt/module_test.cpp index 177b45b96..c4f450ea9 100644 --- a/3rdparty/spirv-tools/test/opt/module_test.cpp +++ b/3rdparty/spirv-tools/test/opt/module_test.cpp @@ -12,25 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include #include "gmock/gmock.h" #include "gtest/gtest.h" - -#include "message.h" -#include "opt/build_module.h" -#include "opt/module.h" +#include "source/opt/build_module.h" +#include "source/opt/module.h" #include "spirv-tools/libspirv.hpp" +#include "test/opt/module_utils.h" -#include "module_utils.h" - +namespace spvtools { +namespace opt { namespace { using ::testing::Eq; using spvtest::GetIdBound; -using spvtools::ir::IRContext; -using spvtools::ir::Module; TEST(ModuleTest, SetIdBound) { Module m; @@ -140,4 +139,6 @@ OpFunctionEnd)"; EXPECT_EQ(text, str.str()); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/module_utils.h b/3rdparty/spirv-tools/test/opt/module_utils.h index ad59058d1..007f132c2 100644 --- a/3rdparty/spirv-tools/test/opt/module_utils.h +++ b/3rdparty/spirv-tools/test/opt/module_utils.h @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_OPT_MODULE_UTILS_H_ -#define LIBSPIRV_TEST_OPT_MODULE_UTILS_H_ +#ifndef TEST_OPT_MODULE_UTILS_H_ +#define TEST_OPT_MODULE_UTILS_H_ #include -#include "opt/module.h" +#include "source/opt/module.h" namespace spvtest { -inline uint32_t GetIdBound(const spvtools::ir::Module& m) { +inline uint32_t GetIdBound(const spvtools::opt::Module& m) { std::vector binary; m.ToBinary(&binary, false); // The 5-word header must always exist. @@ -31,4 +31,4 @@ inline uint32_t GetIdBound(const spvtools::ir::Module& m) { } // namespace spvtest -#endif // LIBSPIRV_TEST_OPT_MODULE_UTILS_H_ +#endif // TEST_OPT_MODULE_UTILS_H_ diff --git a/3rdparty/spirv-tools/test/opt/optimizer_test.cpp b/3rdparty/spirv-tools/test/opt/optimizer_test.cpp index ef0e99213..90abc00d0 100644 --- a/3rdparty/spirv-tools/test/opt/optimizer_test.cpp +++ b/3rdparty/spirv-tools/test/opt/optimizer_test.cpp @@ -12,25 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include +#include "gmock/gmock.h" #include "spirv-tools/libspirv.hpp" #include "spirv-tools/optimizer.hpp" +#include "test/opt/pass_fixture.h" -#include "pass_fixture.h" - +namespace spvtools { +namespace opt { namespace { -using spvtools::CreateNullPass; -using spvtools::CreateStripDebugInfoPass; -using spvtools::Optimizer; -using spvtools::SpirvTools; using ::testing::Eq; +// Return a string that contains the minimum instructions needed to form +// a valid module. Other instructions can be appended to this string. +std::string Header() { + return R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; +} + TEST(Optimizer, CanRunNullPassWithDistinctInputOutputVectors) { SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); std::vector binary_in; - tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary_in); + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", + &binary_in); Optimizer opt(SPV_ENV_UNIVERSAL_1_0); opt.RegisterPass(CreateNullPass()); @@ -39,13 +48,15 @@ TEST(Optimizer, CanRunNullPassWithDistinctInputOutputVectors) { std::string disassembly; tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly); - EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); + EXPECT_THAT(disassembly, + Eq(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); } TEST(Optimizer, CanRunTransformingPassWithDistinctInputOutputVectors) { SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); std::vector binary_in; - tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary_in); + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", + &binary_in); Optimizer opt(SPV_ENV_UNIVERSAL_1_0); opt.RegisterPass(CreateStripDebugInfoPass()); @@ -54,7 +65,7 @@ TEST(Optimizer, CanRunTransformingPassWithDistinctInputOutputVectors) { std::string disassembly; tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly); - EXPECT_THAT(disassembly, Eq("%void = OpTypeVoid\n")); + EXPECT_THAT(disassembly, Eq(Header() + "%void = OpTypeVoid\n")); } TEST(Optimizer, CanRunNullPassWithAliasedVectors) { @@ -74,7 +85,7 @@ TEST(Optimizer, CanRunNullPassWithAliasedVectors) { TEST(Optimizer, CanRunNullPassWithAliasedVectorDataButDifferentSize) { SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); std::vector binary; - tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); Optimizer opt(SPV_ENV_UNIVERSAL_1_0); opt.RegisterPass(CreateNullPass()); @@ -89,13 +100,14 @@ TEST(Optimizer, CanRunNullPassWithAliasedVectorDataButDifferentSize) { std::string disassembly; tools.Disassemble(binary.data(), binary.size(), &disassembly); - EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); + EXPECT_THAT(disassembly, + Eq(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); } TEST(Optimizer, CanRunTransformingPassWithAliasedVectors) { SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); std::vector binary; - tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); Optimizer opt(SPV_ENV_UNIVERSAL_1_0); opt.RegisterPass(CreateStripDebugInfoPass()); @@ -103,7 +115,113 @@ TEST(Optimizer, CanRunTransformingPassWithAliasedVectors) { std::string disassembly; tools.Disassemble(binary.data(), binary.size(), &disassembly); - EXPECT_THAT(disassembly, Eq("%void = OpTypeVoid\n")); + EXPECT_THAT(disassembly, Eq(Header() + "%void = OpTypeVoid\n")); +} + +TEST(Optimizer, CanValidateFlags) { + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + EXPECT_FALSE(opt.FlagHasValidForm("bad-flag")); + EXPECT_TRUE(opt.FlagHasValidForm("-O")); + EXPECT_TRUE(opt.FlagHasValidForm("-Os")); + EXPECT_FALSE(opt.FlagHasValidForm("-O2")); + EXPECT_TRUE(opt.FlagHasValidForm("--this_flag")); +} + +TEST(Optimizer, CanRegisterPassesFromFlags) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + + spv_message_level_t msg_level; + const char* msg_fname; + spv_position_t msg_position; + const char* msg; + auto examine_message = [&msg_level, &msg_fname, &msg_position, &msg]( + spv_message_level_t ml, const char* f, + const spv_position_t& p, const char* m) { + msg_level = ml; + msg_fname = f; + msg_position = p; + msg = m; + }; + opt.SetMessageConsumer(examine_message); + + std::vector pass_flags = { + "--strip-debug", + "--strip-reflect", + "--set-spec-const-default-value=23:42 21:12", + "--if-conversion", + "--freeze-spec-const", + "--inline-entry-points-exhaustive", + "--inline-entry-points-opaque", + "--convert-local-access-chains", + "--eliminate-dead-code-aggressive", + "--eliminate-insert-extract", + "--eliminate-local-single-block", + "--eliminate-local-single-store", + "--merge-blocks", + "--merge-return", + "--eliminate-dead-branches", + "--eliminate-dead-functions", + "--eliminate-local-multi-store", + "--eliminate-common-uniform", + "--eliminate-dead-const", + "--eliminate-dead-inserts", + "--eliminate-dead-variables", + "--fold-spec-const-op-composite", + "--loop-unswitch", + "--scalar-replacement=300", + "--scalar-replacement", + "--strength-reduction", + "--unify-const", + "--flatten-decorations", + "--compact-ids", + "--cfg-cleanup", + "--local-redundancy-elimination", + "--loop-invariant-code-motion", + "--reduce-load-size", + "--redundancy-elimination", + "--private-to-local", + "--remove-duplicates", + "--workaround-1209", + "--replace-invalid-opcode", + "--simplify-instructions", + "--ssa-rewrite", + "--copy-propagate-arrays", + "--loop-fission=20", + "--loop-fusion=2", + "--loop-unroll", + "--vector-dce", + "--loop-unroll-partial=3", + "--loop-peeling", + "--ccp", + "-O", + "-Os", + "--legalize-hlsl"}; + EXPECT_TRUE(opt.RegisterPassesFromFlags(pass_flags)); + + // Test some invalid flags. + EXPECT_FALSE(opt.RegisterPassFromFlag("-O2")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("-loop-unroll")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--set-spec-const-default-value")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--scalar-replacement=s")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--loop-fission=-4")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--loop-fusion=xx")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--loop-unroll-partial")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/pass_fixture.h b/3rdparty/spirv-tools/test/opt/pass_fixture.h index d935231c9..9d9eb3661 100644 --- a/3rdparty/spirv-tools/test/opt/pass_fixture.h +++ b/3rdparty/spirv-tools/test/opt/pass_fixture.h @@ -12,27 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_OPT_PASS_FIXTURE_H_ -#define LIBSPIRV_TEST_OPT_PASS_FIXTURE_H_ +#ifndef TEST_OPT_PASS_FIXTURE_H_ +#define TEST_OPT_PASS_FIXTURE_H_ #include +#include #include #include +#include #include -#include +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/pass_manager.h" +#include "source/opt/passes.h" +#include "source/util/make_unique.h" +#include "spirv-tools/libspirv.hpp" #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "opt/build_module.h" -#include "opt/make_unique.h" -#include "opt/pass_manager.h" -#include "opt/passes.h" -#include "spirv-tools/libspirv.hpp" - namespace spvtools { +namespace opt { // Template class for testing passes. It contains some handy utility methods for // running passes and checking results. @@ -48,22 +50,21 @@ class PassTest : public TestT { : consumer_(nullptr), context_(nullptr), tools_(SPV_ENV_UNIVERSAL_1_1), - manager_(new opt::PassManager()), + manager_(new PassManager()), assemble_options_(SpirvTools::kDefaultAssembleOption), disassemble_options_(SpirvTools::kDefaultDisassembleOption) {} // Runs the given |pass| on the binary assembled from the |original|. // Returns a tuple of the optimized binary and the boolean value returned // from pass Process() function. - std::tuple, opt::Pass::Status> OptimizeToBinary( - opt::Pass* pass, const std::string& original, bool skip_nop) { + std::tuple, Pass::Status> OptimizeToBinary( + Pass* pass, const std::string& original, bool skip_nop) { context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer_, original, assemble_options_)); EXPECT_NE(nullptr, context()) << "Assembling failed for shader:\n" << original << std::endl; if (!context()) { - return std::make_tuple(std::vector(), - opt::Pass::Status::Failure); + return std::make_tuple(std::vector(), Pass::Status::Failure); } const auto status = pass->Run(context()); @@ -77,7 +78,7 @@ class PassTest : public TestT { // |assembly|. Returns a tuple of the optimized binary and the boolean value // from the pass Process() function. template - std::tuple, opt::Pass::Status> SinglePassRunToBinary( + std::tuple, Pass::Status> SinglePassRunToBinary( const std::string& assembly, bool skip_nop, Args&&... args) { auto pass = MakeUnique(std::forward(args)...); pass->SetMessageConsumer(consumer_); @@ -88,11 +89,11 @@ class PassTest : public TestT { // |assembly|, disassembles the optimized binary. Returns a tuple of // disassembly string and the boolean value from the pass Process() function. template - std::tuple SinglePassRunAndDisassemble( + std::tuple SinglePassRunAndDisassemble( const std::string& assembly, bool skip_nop, bool do_validation, Args&&... args) { std::vector optimized_bin; - auto status = opt::Pass::Status::SuccessWithoutChange; + auto status = Pass::Status::SuccessWithoutChange; std::tie(optimized_bin, status) = SinglePassRunToBinary( assembly, skip_nop, std::forward(args)...); if (do_validation) { @@ -124,13 +125,13 @@ class PassTest : public TestT { const std::string& expected, bool skip_nop, bool do_validation, Args&&... args) { std::vector optimized_bin; - auto status = opt::Pass::Status::SuccessWithoutChange; + auto status = Pass::Status::SuccessWithoutChange; std::tie(optimized_bin, status) = SinglePassRunToBinary( original, skip_nop, std::forward(args)...); // Check whether the pass returns the correct modification indication. - EXPECT_NE(opt::Pass::Status::Failure, status); + EXPECT_NE(Pass::Status::Failure, status); EXPECT_EQ(original == expected, - status == opt::Pass::Status::SuccessWithoutChange); + status == Pass::Status::SuccessWithoutChange); if (do_validation) { spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1; spv_context spvContext = spvContextCreate(target_env); @@ -190,7 +191,7 @@ class PassTest : public TestT { // Renews the pass manager, including clearing all previously added passes. void RenewPassManger() { - manager_.reset(new opt::PassManager()); + manager_ = MakeUnique(); manager_->SetMessageConsumer(consumer_); } @@ -224,7 +225,7 @@ class PassTest : public TestT { } MessageConsumer consumer() { return consumer_; } - ir::IRContext* context() { return context_.get(); } + IRContext* context() { return context_.get(); } void SetMessageConsumer(MessageConsumer msg_consumer) { consumer_ = msg_consumer; @@ -232,13 +233,14 @@ class PassTest : public TestT { private: MessageConsumer consumer_; // Message consumer. - std::unique_ptr context_; // IR context + std::unique_ptr context_; // IR context SpirvTools tools_; // An instance for calling SPIRV-Tools functionalities. - std::unique_ptr manager_; // The pass manager. + std::unique_ptr manager_; // The pass manager. uint32_t assemble_options_; uint32_t disassemble_options_; }; +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_TEST_OPT_PASS_FIXTURE_H_ +#endif // TEST_OPT_PASS_FIXTURE_H_ diff --git a/3rdparty/spirv-tools/test/opt/pass_manager_test.cpp b/3rdparty/spirv-tools/test/opt/pass_manager_test.cpp index 2066eff97..c7273e9c1 100644 --- a/3rdparty/spirv-tools/test/opt/pass_manager_test.cpp +++ b/3rdparty/spirv-tools/test/opt/pass_manager_test.cpp @@ -12,22 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gmock/gmock.h" - #include +#include +#include +#include +#include -#include "module_utils.h" -#include "opt/make_unique.h" -#include "pass_fixture.h" +#include "gmock/gmock.h" +#include "source/util/make_unique.h" +#include "test/opt/module_utils.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using spvtest::GetIdBound; using ::testing::Eq; // A null pass whose construtors accept arguments -class NullPassWithArgs : public opt::NullPass { +class NullPassWithArgs : public NullPass { public: NullPassWithArgs(uint32_t) {} NullPassWithArgs(std::string) {} @@ -38,19 +42,19 @@ class NullPassWithArgs : public opt::NullPass { }; TEST(PassManager, Interface) { - opt::PassManager manager; + PassManager manager; EXPECT_EQ(0u, manager.NumPasses()); - manager.AddPass(); + manager.AddPass(); EXPECT_EQ(1u, manager.NumPasses()); EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); - manager.AddPass(MakeUnique()); + manager.AddPass(MakeUnique()); EXPECT_EQ(2u, manager.NumPasses()); EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); EXPECT_STREQ("null", manager.GetPass(1)->name()); - manager.AddPass(); + manager.AddPass(); EXPECT_EQ(3u, manager.NumPasses()); EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); EXPECT_STREQ("null", manager.GetPass(1)->name()); @@ -71,25 +75,25 @@ TEST(PassManager, Interface) { } // A pass that appends an OpNop instruction to the debug1 section. -class AppendOpNopPass : public opt::Pass { +class AppendOpNopPass : public Pass { public: const char* name() const override { return "AppendOpNop"; } - Status Process(ir::IRContext* irContext) override { - irContext->AddDebug1Inst(MakeUnique(irContext)); + Status Process() override { + context()->AddDebug1Inst(MakeUnique(context())); return Status::SuccessWithChange; } }; // A pass that appends specified number of OpNop instructions to the debug1 // section. -class AppendMultipleOpNopPass : public opt::Pass { +class AppendMultipleOpNopPass : public Pass { public: explicit AppendMultipleOpNopPass(uint32_t num_nop) : num_nop_(num_nop) {} const char* name() const override { return "AppendOpNop"; } - Status Process(ir::IRContext* irContext) override { + Status Process() override { for (uint32_t i = 0; i < num_nop_; i++) { - irContext->AddDebug1Inst(MakeUnique(irContext)); + context()->AddDebug1Inst(MakeUnique(context())); } return Status::SuccessWithChange; } @@ -99,13 +103,13 @@ class AppendMultipleOpNopPass : public opt::Pass { }; // A pass that duplicates the last instruction in the debug1 section. -class DuplicateInstPass : public opt::Pass { +class DuplicateInstPass : public Pass { public: const char* name() const override { return "DuplicateInst"; } - Status Process(ir::IRContext* irContext) override { - auto inst = MakeUnique( - *(--irContext->debug1_end())->Clone(irContext)); - irContext->AddDebug1Inst(std::move(inst)); + Status Process() override { + auto inst = + MakeUnique(*(--context()->debug1_end())->Clone(context())); + context()->AddDebug1Inst(std::move(inst)); return Status::SuccessWithChange; } }; @@ -135,15 +139,15 @@ TEST_F(PassManagerTest, Run) { } // A pass that appends an OpTypeVoid instruction that uses a given id. -class AppendTypeVoidInstPass : public opt::Pass { +class AppendTypeVoidInstPass : public Pass { public: explicit AppendTypeVoidInstPass(uint32_t result_id) : result_id_(result_id) {} const char* name() const override { return "AppendTypeVoidInstPass"; } - Status Process(ir::IRContext* irContext) override { - auto inst = MakeUnique( - irContext, SpvOpTypeVoid, 0, result_id_, std::vector{}); - irContext->AddType(std::move(inst)); + Status Process() override { + auto inst = MakeUnique(context(), SpvOpTypeVoid, 0, result_id_, + std::vector{}); + context()->AddType(std::move(inst)); return Status::SuccessWithChange; } @@ -152,10 +156,10 @@ class AppendTypeVoidInstPass : public opt::Pass { }; TEST(PassManager, RecomputeIdBoundAutomatically) { - opt::PassManager manager; - std::unique_ptr module(new ir::Module()); - ir::IRContext context(SPV_ENV_UNIVERSAL_1_2, std::move(module), - manager.consumer()); + PassManager manager; + std::unique_ptr module(new Module()); + IRContext context(SPV_ENV_UNIVERSAL_1_2, std::move(module), + manager.consumer()); EXPECT_THAT(GetIdBound(*context.module()), Eq(0u)); manager.Run(&context); @@ -184,3 +188,5 @@ TEST(PassManager, RecomputeIdBoundAutomatically) { } } // anonymous namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/pass_merge_return_test.cpp b/3rdparty/spirv-tools/test/opt/pass_merge_return_test.cpp index a4f482f85..4dd4b6b28 100644 --- a/3rdparty/spirv-tools/test/opt/pass_merge_return_test.cpp +++ b/3rdparty/spirv-tools/test/opt/pass_merge_return_test.cpp @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include "gmock/gmock.h" #include "spirv-tools/libspirv.hpp" #include "spirv-tools/optimizer.hpp" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" -#include "pass_fixture.h" -#include "pass_utils.h" - +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using MergeReturnPassTest = PassTest<::testing::Test>; TEST_F(MergeReturnPassTest, OneReturn) { @@ -46,7 +46,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } TEST_F(MergeReturnPassTest, TwoReturnsNoValue) { @@ -96,7 +96,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } TEST_F(MergeReturnPassTest, TwoReturnsWithValues) { @@ -145,7 +145,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } TEST_F(MergeReturnPassTest, UnreachableReturnsNoValue) { @@ -199,7 +199,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } TEST_F(MergeReturnPassTest, UnreachableReturnsWithValues) { @@ -254,7 +254,7 @@ OpFunctionEnd SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } #ifdef SPIRV_EFFCEE @@ -268,11 +268,11 @@ TEST_F(MergeReturnPassTest, StructuredControlFlowWithUnreachableMerge) { ; CHECK: OpSelectionMerge [[merge_lab:%\w+]] ; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] ; CHECK: [[if_lab]] = OpLabel -; CHECK-Next: OpStore [[var]] [[true]] -; CHECK-Next: OpBranch +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch ; CHECK: [[then_lab]] = OpLabel -; CHECK-Next: OpStore [[var]] [[true]] -; CHECK-Next: OpBranch [[merge_lab]] +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[merge_lab]] ; CHECK: OpReturn OpCapability Addresses OpCapability Shader @@ -296,7 +296,8 @@ OpUnreachable OpFunctionEnd )"; - SinglePassRunAndMatch(before, false); + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); } TEST_F(MergeReturnPassTest, StructuredControlFlowAddPhi) { @@ -310,11 +311,10 @@ TEST_F(MergeReturnPassTest, StructuredControlFlowAddPhi) { ; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] ; CHECK: [[if_lab]] = OpLabel ; CHECK-NEXT: [[add:%\w+]] = OpIAdd [[type:%\w+]] -; CHECK-Next: OpStore [[var]] [[true]] -; CHECK-Next: OpBranch +; CHECK-NEXT: OpBranch ; CHECK: [[then_lab]] = OpLabel -; CHECK-Next: OpStore [[var]] [[true]] -; CHECK-Next: OpBranch [[merge_lab]] +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[merge_lab]] ; CHECK: [[merge_lab]] = OpLabel ; CHECK-NEXT: [[phi:%\w+]] = OpPhi [[type]] [[add]] [[if_lab]] [[undef:%\w+]] [[then_lab]] ; CHECK: OpIAdd [[type]] [[phi]] [[phi]] @@ -345,8 +345,218 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(before, false); + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); } + +TEST_F(MergeReturnPassTest, StructuredControlDecoration) { + const std::string before = + R"( +; CHECK: OpDecorate [[dec_id:%\w+]] RelaxedPrecision +; CHECK: [[false:%\w+]] = OpConstantFalse +; CHECK: [[true:%\w+]] = OpConstantTrue +; CHECK: OpFunction +; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]] +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] +; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] +; CHECK: [[if_lab]] = OpLabel +; CHECK-NEXT: [[dec_id]] = OpIAdd [[type:%\w+]] +; CHECK-NEXT: OpBranch +; CHECK: [[then_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[merge_lab]] +; CHECK: [[merge_lab]] = OpLabel +; CHECK: OpReturn +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +OpDecorate %11 RelaxedPrecision +%2 = OpTypeVoid +%3 = OpTypeBool +%int = OpTypeInt 32 0 +%int_0 = OpConstant %int 0 +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %4 %8 %9 +%8 = OpLabel +%11 = OpIAdd %int %int_0 %int_0 +OpBranch %10 +%9 = OpLabel +OpReturn +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, StructuredControlDecoration2) { + const std::string before = + R"( +; CHECK: OpDecorate [[dec_id:%\w+]] RelaxedPrecision +; CHECK: [[false:%\w+]] = OpConstantFalse +; CHECK: [[true:%\w+]] = OpConstantTrue +; CHECK: OpFunction +; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]] +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] +; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] +; CHECK: [[if_lab]] = OpLabel +; CHECK-NEXT: [[dec_id]] = OpIAdd [[type:%\w+]] +; CHECK-NEXT: OpBranch +; CHECK: [[then_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[merge_lab]] +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi [[type]] [[dec_id]] [[if_lab]] [[undef:%\w+]] [[then_lab]] +; CHECK: OpIAdd [[type]] [[phi]] [[phi]] +; CHECK: OpReturn +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +OpDecorate %11 RelaxedPrecision +%2 = OpTypeVoid +%3 = OpTypeBool +%int = OpTypeInt 32 0 +%int_0 = OpConstant %int 0 +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %4 %8 %9 +%8 = OpLabel +%11 = OpIAdd %int %int_0 %int_0 +OpBranch %10 +%9 = OpLabel +OpReturn +%10 = OpLabel +%12 = OpIAdd %int %11 %11 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, SplitBlockUsedInPhi) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK: OpSelectionMerge [[merge1:%\w+]] None +; CHECK: [[merge1]] = OpLabel +; CHECK: OpBranchConditional %{{\w+}} %{{\w+}} [[old_merge:%\w+]] +; CHECK: [[old_merge]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[merge2:%\w+]] +; CHECK-NEXT: OpBranchConditional %false [[side_node:%\w+]] [[merge2]] +; CHECK: [[merge2]] = OpLabel +; CHECK-NEXT: OpPhi %bool %false [[old_merge]] %true [[side_node]] + OpCapability Addresses + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "simple_shader" + %void = OpTypeVoid + %bool = OpTypeBool + %false = OpConstantFalse %bool + %true = OpConstantTrue %bool + %6 = OpTypeFunction %void + %1 = OpFunction %void None %6 + %7 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %false %9 %8 + %9 = OpLabel + OpReturn + %8 = OpLabel + OpSelectionMerge %10 None + OpBranchConditional %false %11 %10 + %11 = OpLabel + OpBranch %10 + %10 = OpLabel + %12 = OpPhi %bool %false %8 %true %11 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} + +// TODO(#1861): Reenable these test when the breaks from selection constructs +// are reenabled. +/* +TEST_F(MergeReturnPassTest, UpdateOrderWhenPredicating) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK: OpFunction +; CHECK: OpSelectionMerge [[m1:%\w+]] None +; CHECK-NOT: OpReturn +; CHECK: [[m1]] = OpLabel +; CHECK: OpSelectionMerge [[m2:%\w+]] None +; CHECK: OpSelectionMerge [[m3:%\w+]] None +; CHECK: OpSelectionMerge [[m4:%\w+]] None +; CHECK: OpLabel +; CHECK-NEXT: OpStore +; CHECK-NEXT: OpBranch [[m4]] +; CHECK: [[m4]] = OpLabel +; CHECK-NEXT: [[ld4:%\w+]] = OpLoad %bool +; CHECK-NEXT: OpBranchConditional [[ld4]] [[m3]] +; CHECK: [[m3]] = OpLabel +; CHECK-NEXT: [[ld3:%\w+]] = OpLoad %bool +; CHECK-NEXT: OpBranchConditional [[ld3]] [[m2]] +; CHECK: [[m2]] = OpLabel + OpCapability SampledBuffer + OpCapability StorageImageExtendedFormats + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "PS_DebugTiles" + OpExecutionMode %1 OriginUpperLeft + OpSource HLSL 600 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %1 = OpFunction %void None %3 + %5 = OpLabel + %6 = OpFunctionCall %void %7 + OpReturn + OpFunctionEnd + %7 = OpFunction %void None %3 + %8 = OpLabel + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpReturn + %10 = OpLabel + %12 = OpUndef %bool + OpSelectionMerge %13 None + OpBranchConditional %12 %14 %15 + %15 = OpLabel + %16 = OpUndef %bool + OpSelectionMerge %17 None + OpBranchConditional %16 %18 %17 + %18 = OpLabel + OpReturn + %17 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + %13 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} +*/ #endif TEST_F(MergeReturnPassTest, StructuredControlFlowBothMergeAndHeader) { @@ -429,9 +639,12 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } +// TODO(#1861): Reenable these test when the breaks from selection constructs +// are reenabled. +/* TEST_F(MergeReturnPassTest, NestedSelectionMerge) { const std::string before = R"( @@ -454,16 +667,16 @@ TEST_F(MergeReturnPassTest, NestedSelectionMerge) { OpReturn %11 = OpLabel OpSelectionMerge %12 None - OpBranchConditional %false %14 %15 - %14 = OpLabel - %16 = OpIAdd %uint %uint_0 %uint_0 + OpBranchConditional %false %13 %14 + %13 = OpLabel + %15 = OpIAdd %uint %uint_0 %uint_0 OpBranch %12 - %15 = OpLabel + %14 = OpLabel OpReturn %12 = OpLabel OpBranch %9 %9 = OpLabel - %17 = OpIAdd %uint %16 %16 + %16 = OpIAdd %uint %15 %15 OpReturn OpFunctionEnd )"; @@ -482,7 +695,7 @@ OpEntryPoint GLCompute %1 "simple_shader" %7 = OpTypeFunction %void %_ptr_Function_bool = OpTypePointer Function %bool %true = OpConstantTrue %bool -%24 = OpUndef %uint +%26 = OpUndef %uint %1 = OpFunction %void None %7 %8 = OpLabel %19 = OpVariable %_ptr_Function_bool Function %false @@ -501,25 +714,29 @@ OpBranch %12 OpStore %19 %true OpBranch %12 %12 = OpLabel -%25 = OpPhi %uint %15 %13 %24 %14 +%27 = OpPhi %uint %15 %13 %26 %14 +%22 = OpLoad %bool %19 +OpBranchConditional %22 %9 %21 +%21 = OpLabel OpBranch %9 %9 = OpLabel -%26 = OpPhi %uint %25 %12 %24 %10 -%23 = OpLoad %bool %19 -OpSelectionMerge %22 None -OpBranchConditional %23 %22 %21 -%21 = OpLabel -%16 = OpIAdd %uint %26 %26 +%28 = OpPhi %uint %27 %21 %26 %10 %26 %12 +%25 = OpLoad %bool %19 +OpSelectionMerge %24 None +OpBranchConditional %25 %24 %23 +%23 = OpLabel +%16 = OpIAdd %uint %28 %28 OpStore %19 %true -OpBranch %22 -%22 = OpLabel +OpBranch %24 +%24 = OpLabel OpBranch %17 %17 = OpLabel OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, false, true); + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, false, true); } // This is essentially the same as NestedSelectionMerge, except @@ -527,8 +744,7 @@ OpFunctionEnd // work even if the order of the traversals change. TEST_F(MergeReturnPassTest, NestedSelectionMerge2) { const std::string before = - R"( - OpCapability Addresses + R"( OpCapability Addresses OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -547,16 +763,16 @@ TEST_F(MergeReturnPassTest, NestedSelectionMerge2) { OpReturn %10 = OpLabel OpSelectionMerge %12 None - OpBranchConditional %false %14 %15 - %14 = OpLabel - %16 = OpIAdd %uint %uint_0 %uint_0 + OpBranchConditional %false %13 %14 + %13 = OpLabel + %15 = OpIAdd %uint %uint_0 %uint_0 OpBranch %12 - %15 = OpLabel + %14 = OpLabel OpReturn %12 = OpLabel OpBranch %9 %9 = OpLabel - %17 = OpIAdd %uint %16 %16 + %16 = OpIAdd %uint %15 %15 OpReturn OpFunctionEnd )"; @@ -575,7 +791,7 @@ OpEntryPoint GLCompute %1 "simple_shader" %7 = OpTypeFunction %void %_ptr_Function_bool = OpTypePointer Function %bool %true = OpConstantTrue %bool -%24 = OpUndef %uint +%26 = OpUndef %uint %1 = OpFunction %void None %7 %8 = OpLabel %19 = OpVariable %_ptr_Function_bool Function %false @@ -594,15 +810,18 @@ OpBranch %12 OpStore %19 %true OpBranch %12 %12 = OpLabel -%25 = OpPhi %uint %15 %13 %24 %14 +%27 = OpPhi %uint %15 %13 %26 %14 +%25 = OpLoad %bool %19 +OpBranchConditional %25 %9 %24 +%24 = OpLabel OpBranch %9 %9 = OpLabel -%26 = OpPhi %uint %25 %12 %24 %11 +%28 = OpPhi %uint %27 %24 %26 %11 %26 %12 %23 = OpLoad %bool %19 OpSelectionMerge %22 None OpBranchConditional %23 %22 %21 %21 = OpLabel -%16 = OpIAdd %uint %26 %26 +%16 = OpIAdd %uint %28 %28 OpStore %19 %true OpBranch %22 %22 = OpLabel @@ -612,13 +831,12 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } TEST_F(MergeReturnPassTest, NestedSelectionMerge3) { const std::string before = - R"( - OpCapability Addresses + R"( OpCapability Addresses OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -636,17 +854,17 @@ TEST_F(MergeReturnPassTest, NestedSelectionMerge3) { %11 = OpLabel OpReturn %10 = OpLabel - %16 = OpIAdd %uint %uint_0 %uint_0 - OpSelectionMerge %12 None + %12 = OpIAdd %uint %uint_0 %uint_0 + OpSelectionMerge %13 None OpBranchConditional %false %14 %15 %14 = OpLabel - OpBranch %12 + OpBranch %13 %15 = OpLabel OpReturn - %12 = OpLabel + %13 = OpLabel OpBranch %9 %9 = OpLabel - %17 = OpIAdd %uint %16 %16 + %16 = OpIAdd %uint %12 %12 OpReturn OpFunctionEnd )"; @@ -665,7 +883,7 @@ OpEntryPoint GLCompute %1 "simple_shader" %7 = OpTypeFunction %void %_ptr_Function_bool = OpTypePointer Function %bool %true = OpConstantTrue %bool -%24 = OpUndef %uint +%26 = OpUndef %uint %1 = OpFunction %void None %7 %8 = OpLabel %19 = OpVariable %_ptr_Function_bool Function %false @@ -684,14 +902,17 @@ OpBranch %13 OpStore %19 %true OpBranch %13 %13 = OpLabel +%25 = OpLoad %bool %19 +OpBranchConditional %25 %9 %24 +%24 = OpLabel OpBranch %9 %9 = OpLabel -%25 = OpPhi %uint %12 %13 %24 %11 +%27 = OpPhi %uint %12 %24 %26 %11 %26 %13 %23 = OpLoad %bool %19 OpSelectionMerge %22 None OpBranchConditional %23 %22 %21 %21 = OpLabel -%16 = OpIAdd %uint %25 %25 +%16 = OpIAdd %uint %27 %27 OpStore %19 %true OpBranch %22 %22 = OpLabel @@ -701,6 +922,238 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndCheck(before, after, false, true); + SinglePassRunAndCheck(before, after, false, true); } -} // anonymous namespace +*/ + +TEST_F(MergeReturnPassTest, NestedLoopMerge) { + const std::string before = + R"( OpCapability SampledBuffer + OpCapability StorageImageExtendedFormats + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "CS" + OpExecutionMode %2 LocalSize 8 8 1 + OpSource HLSL 600 + OpName %function "function" + %uint = OpTypeInt 32 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %v3uint = OpTypeVector %uint 3 + %bool = OpTypeBool + %true = OpConstantTrue %bool +%_ptr_Function_uint = OpTypePointer Function %uint + %_struct_13 = OpTypeStruct %v3uint %v3uint %v3uint %uint %uint %uint %uint %uint %uint + %2 = OpFunction %void None %6 + %14 = OpLabel + %15 = OpFunctionCall %void %function + OpReturn + OpFunctionEnd + %function = OpFunction %void None %6 + %16 = OpLabel + %17 = OpVariable %_ptr_Function_uint Function + %18 = OpVariable %_ptr_Function_uint Function + OpStore %17 %uint_0 + OpBranch %19 + %19 = OpLabel + %20 = OpLoad %uint %17 + %21 = OpULessThan %bool %20 %uint_1 + OpLoopMerge %22 %23 DontUnroll + OpBranchConditional %21 %24 %22 + %24 = OpLabel + OpStore %18 %uint_1 + OpBranch %25 + %25 = OpLabel + %26 = OpLoad %uint %18 + %27 = OpINotEqual %bool %26 %uint_0 + OpLoopMerge %28 %29 DontUnroll + OpBranchConditional %27 %30 %28 + %30 = OpLabel + OpSelectionMerge %31 None + OpBranchConditional %true %32 %31 + %32 = OpLabel + OpReturn + %31 = OpLabel + OpStore %18 %uint_1 + OpBranch %29 + %29 = OpLabel + OpBranch %25 + %28 = OpLabel + OpBranch %23 + %23 = OpLabel + %33 = OpLoad %uint %17 + %34 = OpIAdd %uint %33 %uint_1 + OpStore %17 %34 + OpBranch %19 + %22 = OpLabel + OpReturn + OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability SampledBuffer +OpCapability StorageImageExtendedFormats +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %2 "CS" +OpExecutionMode %2 LocalSize 8 8 1 +OpSource HLSL 600 +OpName %function "function" +%uint = OpTypeInt 32 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%v3uint = OpTypeVector %uint 3 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%_ptr_Function_uint = OpTypePointer Function %uint +%_struct_13 = OpTypeStruct %v3uint %v3uint %v3uint %uint %uint %uint %uint %uint %uint +%false = OpConstantFalse %bool +%_ptr_Function_bool = OpTypePointer Function %bool +%2 = OpFunction %void None %6 +%14 = OpLabel +%15 = OpFunctionCall %void %function +OpReturn +OpFunctionEnd +%function = OpFunction %void None %6 +%16 = OpLabel +%38 = OpVariable %_ptr_Function_bool Function %false +%17 = OpVariable %_ptr_Function_uint Function +%18 = OpVariable %_ptr_Function_uint Function +OpStore %17 %uint_0 +OpBranch %19 +%19 = OpLabel +%20 = OpLoad %uint %17 +%21 = OpULessThan %bool %20 %uint_1 +OpLoopMerge %22 %23 DontUnroll +OpBranchConditional %21 %24 %22 +%24 = OpLabel +OpStore %18 %uint_1 +OpBranch %25 +%25 = OpLabel +%26 = OpLoad %uint %18 +%27 = OpINotEqual %bool %26 %uint_0 +OpLoopMerge %28 %29 DontUnroll +OpBranchConditional %27 %30 %28 +%30 = OpLabel +OpSelectionMerge %31 None +OpBranchConditional %true %32 %31 +%32 = OpLabel +OpStore %38 %true +OpBranch %28 +%31 = OpLabel +OpStore %18 %uint_1 +OpBranch %29 +%29 = OpLabel +OpBranch %25 +%28 = OpLabel +%40 = OpLoad %bool %38 +OpBranchConditional %40 %22 %39 +%39 = OpLabel +OpBranch %23 +%23 = OpLabel +%33 = OpLoad %uint %17 +%34 = OpIAdd %uint %33 %uint_1 +OpStore %17 %34 +OpBranch %19 +%22 = OpLabel +%43 = OpLoad %bool %38 +OpSelectionMerge %42 None +OpBranchConditional %43 %42 %41 +%41 = OpLabel +OpStore %38 %true +OpBranch %42 +%42 = OpLabel +OpBranch %35 +%35 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, ReturnValueDecoration) { + const std::string before = + R"(OpCapability Linkage +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %11 "simple_shader" +OpDecorate %7 RelaxedPrecision +%12 = OpTypeVoid +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%13 = OpTypeFunction %12 +%11 = OpFunction %12 None %13 +%l1 = OpLabel +OpReturn +OpFunctionEnd +%7 = OpFunction %1 None %6 +%8 = OpLabel +OpBranchConditional %3 %9 %10 +%9 = OpLabel +OpReturnValue %4 +%10 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Linkage +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %11 "simple_shader" +OpDecorate %7 RelaxedPrecision +OpDecorate %17 RelaxedPrecision +OpDecorate %18 RelaxedPrecision +%12 = OpTypeVoid +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%13 = OpTypeFunction %12 +%16 = OpTypePointer Function %1 +%19 = OpTypePointer Function %2 +%21 = OpConstantTrue %2 +%11 = OpFunction %12 None %13 +%14 = OpLabel +OpReturn +OpFunctionEnd +%7 = OpFunction %1 None %6 +%8 = OpLabel +%20 = OpVariable %19 Function %3 +%17 = OpVariable %16 Function +OpBranchConditional %3 %9 %10 +%9 = OpLabel +OpStore %20 %21 +OpStore %17 %4 +OpBranch %15 +%10 = OpLabel +OpStore %20 %21 +OpStore %17 %5 +OpBranch %15 +%15 = OpLabel +%18 = OpLoad %1 %17 +OpReturnValue %18 +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/pass_remove_duplicates_test.cpp b/3rdparty/spirv-tools/test/opt/pass_remove_duplicates_test.cpp index d269daa41..887fdfdb4 100644 --- a/3rdparty/spirv-tools/test/opt/pass_remove_duplicates_test.cpp +++ b/3rdparty/spirv-tools/test/opt/pass_remove_duplicates_test.cpp @@ -13,23 +13,22 @@ // limitations under the License. #include +#include +#include +#include -#include - +#include "gmock/gmock.h" #include "source/opt/build_module.h" #include "source/opt/ir_context.h" #include "source/opt/pass_manager.h" #include "source/opt/remove_duplicates_pass.h" #include "source/spirv_constant.h" -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { +namespace opt { namespace { -using spvtools::ir::IRContext; -using spvtools::ir::Instruction; -using spvtools::opt::PassManager; -using spvtools::opt::RemoveDuplicatesPass; - class RemoveDuplicatesTest : public ::testing::Test { public: RemoveDuplicatesTest() @@ -62,7 +61,7 @@ class RemoveDuplicatesTest : public ::testing::Test { tools_.SetMessageConsumer(consumer_); } - virtual void TearDown() override { error_message_.clear(); } + void TearDown() override { error_message_.clear(); } std::string RunPass(const std::string& text) { context_ = spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text); @@ -72,8 +71,8 @@ class RemoveDuplicatesTest : public ::testing::Test { manager.SetMessageConsumer(consumer_); manager.AddPass(); - spvtools::opt::Pass::Status pass_res = manager.Run(context_.get()); - if (pass_res == spvtools::opt::Pass::Status::Failure) return std::string(); + Pass::Status pass_res = manager.Run(context_.get()); + if (pass_res == Pass::Status::Failure) return std::string(); return ModuleToText(); } @@ -129,8 +128,8 @@ OpCapability Linkage OpMemoryModel Logical GLSL450 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, DuplicateExtInstImports) { @@ -149,8 +148,8 @@ OpCapability Linkage OpMemoryModel Logical GLSL450 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, DuplicateTypes) { @@ -169,8 +168,8 @@ OpMemoryModel Logical GLSL450 %3 = OpTypeStruct %1 %1 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, SameTypeDifferentMemberDecoration) { @@ -192,8 +191,8 @@ OpDecorate %1 GLSLPacked %3 = OpTypeStruct %2 %2 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, SameTypeAndMemberDecoration) { @@ -215,8 +214,8 @@ OpDecorate %1 GLSLPacked %1 = OpTypeStruct %3 %3 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, SameTypeAndDifferentName) { @@ -238,8 +237,8 @@ OpName %1 "Type1" %1 = OpTypeStruct %3 %3 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } // Check that #1033 has been fixed. @@ -268,8 +267,8 @@ OpGroupDecorate %3 %1 %2 %3 = OpVariable %4 Uniform )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, DifferentDecorationGroup) { @@ -303,8 +302,345 @@ OpGroupDecorate %2 %4 %4 = OpVariable %5 Uniform )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test what happens when a type is a resource type. For now we are merging +// them, but, if we want to merge types and make reflection work (issue #1372), +// we will not be able to merge %2 and %3 below. +TEST_F(RemoveDuplicatesTest, DontMergeNestedResourceTypes) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpMemberName %3 0 "AdjustXYZ" +OpMemberName %3 1 "AdjustDir" +OpName %4 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpMemberDecorate %3 1 Offset 16 +OpDecorate %3 Block +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%3 = OpTypeStruct %1 %2 +%7 = OpTypePointer Uniform %3 +%4 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpMemberName %3 0 "AdjustXYZ" +OpMemberName %3 1 "AdjustDir" +OpName %4 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpMemberDecorate %3 1 Offset 16 +OpDecorate %3 Block +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%3 = OpTypeStruct %1 %1 +%7 = OpTypePointer Uniform %3 +%4 = OpVariable %7 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// See comment for DontMergeNestedResourceTypes. +TEST_F(RemoveDuplicatesTest, DontMergeResourceTypes) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%7 = OpTypePointer Uniform %1 +%8 = OpTypePointer Uniform %2 +%3 = OpVariable %7 Uniform +%4 = OpVariable %8 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%7 = OpTypePointer Uniform %1 +%3 = OpVariable %7 Uniform +%4 = OpVariable %7 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// See comment for DontMergeNestedResourceTypes. +TEST_F(RemoveDuplicatesTest, DontMergeResourceTypesContainingArray) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 4 +%9 = OpTypeArray %1 %8 +%10 = OpTypeArray %2 %8 +%11 = OpTypePointer Uniform %9 +%12 = OpTypePointer Uniform %10 +%3 = OpVariable %11 Uniform +%4 = OpVariable %12 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 4 +%9 = OpTypeArray %1 %8 +%11 = OpTypePointer Uniform %9 +%3 = OpVariable %11 Uniform +%4 = OpVariable %11 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test that we merge the type of a resource with a type that is not the type +// a resource. The resource type appears first in this case. We must keep +// the resource type. +TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType1) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%2 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%7 = OpTypePointer Uniform %2 +%3 = OpVariable %6 Uniform +%8 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%3 = OpVariable %6 Uniform +%8 = OpVariable %6 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test that we merge the type of a resource with a type that is not the type +// a resource. The resource type appears second in this case. We must keep +// the resource type. +// +// See comment for DontMergeNestedResourceTypes. +TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType2) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%2 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%7 = OpTypePointer Uniform %2 +%8 = OpVariable %6 Uniform +%3 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%8 = OpVariable %6 Uniform +%3 = OpVariable %6 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// In this test, %8 and %9 are the same and only %9 is used in a resource. +// However, we cannot merge them unless we also merge %2 and %3, which cannot +// happen because both are used in resources. +// +// If we try to avoid replaces resource types, then remove duplicates should +// have not change in this case. That is not currently implemented. +TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType3) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource HLSL 600 +OpName %2 "PositionAdjust" +OpMemberName %2 0 "XAdjust" +OpName %3 "NormalAdjust" +OpMemberName %3 0 "XDir" +OpName %4 "Constants" +OpMemberDecorate %2 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +OpDecorate %5 DescriptorSet 1 +OpDecorate %5 Binding 0 +%6 = OpTypeFloat 32 +%7 = OpTypeVector %6 3 +%2 = OpTypeStruct %7 +%3 = OpTypeStruct %7 +%8 = OpTypePointer Uniform %3 +%9 = OpTypePointer Uniform %2 +%10 = OpTypeStruct %3 +%11 = OpTypePointer Uniform %10 +%5 = OpVariable %9 Uniform +%4 = OpVariable %11 Uniform +%12 = OpTypeVoid +%13 = OpTypeFunction %12 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 0 +%1 = OpFunction %12 None %13 +%16 = OpLabel +%17 = OpAccessChain %8 %4 %15 +OpReturn +OpFunctionEnd +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource HLSL 600 +OpName %2 "PositionAdjust" +OpMemberName %2 0 "XAdjust" +OpName %4 "Constants" +OpMemberDecorate %2 0 Offset 0 +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +OpDecorate %5 DescriptorSet 1 +OpDecorate %5 Binding 0 +%6 = OpTypeFloat 32 +%7 = OpTypeVector %6 3 +%2 = OpTypeStruct %7 +%8 = OpTypePointer Uniform %2 +%10 = OpTypeStruct %2 +%11 = OpTypePointer Uniform %10 +%5 = OpVariable %8 Uniform +%4 = OpVariable %11 Uniform +%12 = OpTypeVoid +%13 = OpTypeFunction %12 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 0 +%1 = OpFunction %12 None %13 +%16 = OpLabel +%17 = OpAccessChain %8 %4 %15 +OpReturn +OpFunctionEnd +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/pass_test.cpp b/3rdparty/spirv-tools/test/opt/pass_test.cpp index 5ff1a121d..bce05b679 100644 --- a/3rdparty/spirv-tools/test/opt/pass_test.cpp +++ b/3rdparty/spirv-tools/test/opt/pass_test.cpp @@ -12,32 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include - -#include "assembly_builder.h" -#include "opt/pass.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "gmock/gmock.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; -class DummyPass : public opt::Pass { + +class DummyPass : public Pass { public: const char* name() const override { return "dummy-pass"; } - Status Process(ir::IRContext* irContext) override { - return irContext ? Status::SuccessWithoutChange : Status::Failure; - } + Status Process() override { return Status::SuccessWithoutChange; } }; -} // namespace -namespace { - -using namespace spvtools; using ::testing::UnorderedElementsAre; - using PassClassTest = PassTest<::testing::Test>; TEST_F(PassClassTest, BasicVisitFromEntryPoint) { @@ -76,14 +71,14 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { )"; // clang-format on - std::unique_ptr localContext = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" << text << std::endl; DummyPass testPass; std::vector processed; - opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) { + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { processed.push_back(fp->result_id()); return false; }; @@ -132,7 +127,7 @@ TEST_F(PassClassTest, BasicVisitReachable) { )"; // clang-format on - std::unique_ptr localContext = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" @@ -140,7 +135,7 @@ TEST_F(PassClassTest, BasicVisitReachable) { DummyPass testPass; std::vector processed; - opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) { + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { processed.push_back(fp->result_id()); return false; }; @@ -184,7 +179,7 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { )"; // clang-format on - std::unique_ptr localContext = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" @@ -192,7 +187,7 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { DummyPass testPass; std::vector processed; - opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) { + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { processed.push_back(fp->result_id()); return false; }; @@ -226,7 +221,7 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { )"; // clang-format on - std::unique_ptr localContext = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" @@ -234,11 +229,14 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { DummyPass testPass; std::vector processed; - opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) { + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { processed.push_back(fp->result_id()); return false; }; testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10)); } + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/pass_utils.cpp b/3rdparty/spirv-tools/test/opt/pass_utils.cpp index 95b5181de..ceb999610 100644 --- a/3rdparty/spirv-tools/test/opt/pass_utils.cpp +++ b/3rdparty/spirv-tools/test/opt/pass_utils.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_utils.h" +#include "test/opt/pass_utils.h" #include #include +namespace spvtools { +namespace opt { namespace { // Well, this is another place requiring the knowledge of the grammar and can be @@ -33,8 +35,6 @@ const char* kDebugOpcodes[] = { } // anonymous namespace -namespace spvtools { - bool FindAndReplace(std::string* process_str, const std::string find_str, const std::string replace_str) { if (process_str->empty() || find_str.empty()) { @@ -78,4 +78,5 @@ std::string JoinNonDebugInsts(const std::vector& insts) { insts, [](const char* inst) { return ContainsDebugOpcode(inst); }); } +} // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/pass_utils.h b/3rdparty/spirv-tools/test/opt/pass_utils.h index 4f17700a0..37406842a 100644 --- a/3rdparty/spirv-tools/test/opt/pass_utils.h +++ b/3rdparty/spirv-tools/test/opt/pass_utils.h @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_OPT_PASS_UTILS_H_ -#define LIBSPIRV_TEST_OPT_PASS_UTILS_H_ +#ifndef TEST_OPT_PASS_UTILS_H_ +#define TEST_OPT_PASS_UTILS_H_ +#include #include #include #include #include namespace spvtools { +namespace opt { // In-place substring replacement. Finds the |find_str| in the |process_str| // and replaces the found substring with |replace_str|. Returns true if at @@ -60,6 +62,7 @@ std::vector Concat(const std::vector& a, const std::vector& b) { return ret; } +} // namespace opt } // namespace spvtools -#endif // LIBSPIRV_TEST_OPT_PASS_UTILS_H_ +#endif // TEST_OPT_PASS_UTILS_H_ diff --git a/3rdparty/spirv-tools/test/opt/private_to_local_test.cpp b/3rdparty/spirv-tools/test/opt/private_to_local_test.cpp index d59711b61..d7eb37e51 100644 --- a/3rdparty/spirv-tools/test/opt/private_to_local_test.cpp +++ b/3rdparty/spirv-tools/test/opt/private_to_local_test.cpp @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/value_number_table.h" +#include -#include "assembly_builder.h" #include "gmock/gmock.h" -#include "opt/build_module.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ::testing::HasSubstr; using ::testing::MatchesRegex; - using PrivateToLocalTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -57,7 +57,7 @@ TEST_F(PrivateToLocalTest, ChangeToLocal) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(PrivateToLocalTest, ReuseExistingType) { @@ -89,7 +89,7 @@ TEST_F(PrivateToLocalTest, ReuseExistingType) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(PrivateToLocalTest, UpdateAccessChain) { @@ -127,7 +127,7 @@ TEST_F(PrivateToLocalTest, UpdateAccessChain) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(PrivateToLocalTest, UseTexelPointer) { @@ -172,7 +172,7 @@ OpCapability SampledBuffer OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(PrivateToLocalTest, UsedInTwoFunctions) { @@ -200,9 +200,9 @@ TEST_F(PrivateToLocalTest, UsedInTwoFunctions) { OpReturn OpFunctionEnd )"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(PrivateToLocalTest, UsedInFunctionCall) { @@ -234,9 +234,83 @@ TEST_F(PrivateToLocalTest, UsedInFunctionCall) { OpReturn OpFunctionEnd )"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } + +TEST_F(PrivateToLocalTest, CreatePointerToAmbiguousStruct1) { + // Test that the correct pointer type is picked up. + const std::string text = R"( +; CHECK: [[struct1:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[struct2:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[priv_ptr:%[\w]+]] = OpTypePointer Private [[struct1]] +; CHECK: [[fuct_ptr2:%[\w]+]] = OpTypePointer Function [[struct2]] +; CHECK: [[fuct_ptr1:%[\w]+]] = OpTypePointer Function [[struct1]] +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[fuct_ptr1]] Function +; CHECK: OpLoad [[struct1]] [[newvar]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %struct1 = OpTypeStruct %5 + %struct2 = OpTypeStruct %5 + %6 = OpTypePointer Private %struct1 + %func_ptr2 = OpTypePointer Function %struct2 + %8 = OpVariable %6 Private + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %struct1 %8 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +TEST_F(PrivateToLocalTest, CreatePointerToAmbiguousStruct2) { + // Test that the correct pointer type is picked up. + const std::string text = R"( +; CHECK: [[struct1:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[struct2:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[priv_ptr:%[\w]+]] = OpTypePointer Private [[struct2]] +; CHECK: [[fuct_ptr1:%[\w]+]] = OpTypePointer Function [[struct1]] +; CHECK: [[fuct_ptr2:%[\w]+]] = OpTypePointer Function [[struct2]] +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[fuct_ptr2]] Function +; CHECK: OpLoad [[struct2]] [[newvar]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %struct1 = OpTypeStruct %5 + %struct2 = OpTypeStruct %5 + %6 = OpTypePointer Private %struct2 + %func_ptr2 = OpTypePointer Function %struct1 + %8 = OpVariable %6 Private + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %struct2 %8 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/propagator_test.cpp b/3rdparty/spirv-tools/test/opt/propagator_test.cpp index b6cd2c4c0..fb8e487cc 100644 --- a/3rdparty/spirv-tools/test/opt/propagator_test.cpp +++ b/3rdparty/spirv-tools/test/opt/propagator_test.cpp @@ -12,19 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include +#include +#include -#include "opt/build_module.h" -#include "opt/cfg.h" -#include "opt/ir_context.h" -#include "opt/pass.h" -#include "opt/propagator.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/cfg.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" +#include "source/opt/propagator.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ::testing::UnorderedElementsAre; class PropagatorTest : public testing::Test { @@ -41,8 +45,8 @@ class PropagatorTest : public testing::Test { << input << "\n"; } - bool Propagate(const opt::SSAPropagator::VisitFunction& visit_fn) { - opt::SSAPropagator propagator(ctx_.get(), visit_fn); + bool Propagate(const SSAPropagator::VisitFunction& visit_fn) { + SSAPropagator propagator(ctx_.get(), visit_fn); bool retval = false; for (auto& fn : *ctx_->module()) { retval |= propagator.Run(&fn); @@ -58,7 +62,7 @@ class PropagatorTest : public testing::Test { return values_vec_; } - std::unique_ptr ctx_; + std::unique_ptr ctx_; std::map values_; std::vector values_vec_; }; @@ -101,20 +105,19 @@ TEST_F(PropagatorTest, LocalPropagate) { )"; Assemble(spv_asm); - const auto visit_fn = [this](ir::Instruction* instr, - ir::BasicBlock** dest_bb) { + const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) { *dest_bb = nullptr; if (instr->opcode() == SpvOpStore) { uint32_t lhs_id = instr->GetSingleWordOperand(0); uint32_t rhs_id = instr->GetSingleWordOperand(1); - ir::Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); + Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); if (rhs_def->opcode() == SpvOpConstant) { uint32_t val = rhs_def->GetSingleWordOperand(2); values_[lhs_id] = val; - return opt::SSAPropagator::kInteresting; + return SSAPropagator::kInteresting; } } - return opt::SSAPropagator::kVarying; + return SSAPropagator::kVarying; }; EXPECT_TRUE(Propagate(visit_fn)); @@ -168,37 +171,37 @@ TEST_F(PropagatorTest, PropagateThroughPhis) { Assemble(spv_asm); - ir::Instruction *phi_instr = nullptr; - const auto visit_fn = [this, &phi_instr](ir::Instruction* instr, - ir::BasicBlock** dest_bb) { + Instruction* phi_instr = nullptr; + const auto visit_fn = [this, &phi_instr](Instruction* instr, + BasicBlock** dest_bb) { *dest_bb = nullptr; if (instr->opcode() == SpvOpLoad) { uint32_t rhs_id = instr->GetSingleWordOperand(2); - ir::Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); + Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); if (rhs_def->opcode() == SpvOpConstant) { uint32_t val = rhs_def->GetSingleWordOperand(2); values_[instr->result_id()] = val; - return opt::SSAPropagator::kInteresting; + return SSAPropagator::kInteresting; } } else if (instr->opcode() == SpvOpPhi) { phi_instr = instr; - opt::SSAPropagator::PropStatus retval; + SSAPropagator::PropStatus retval; for (uint32_t i = 2; i < instr->NumOperands(); i += 2) { uint32_t phi_arg_id = instr->GetSingleWordOperand(i); auto it = values_.find(phi_arg_id); if (it != values_.end()) { EXPECT_EQ(it->second, 4u); - retval = opt::SSAPropagator::kInteresting; + retval = SSAPropagator::kInteresting; values_[instr->result_id()] = it->second; } else { - retval = opt::SSAPropagator::kNotInteresting; + retval = SSAPropagator::kNotInteresting; break; } } return retval; } - return opt::SSAPropagator::kVarying; + return SSAPropagator::kVarying; }; EXPECT_TRUE(Propagate(visit_fn)); @@ -212,3 +215,5 @@ TEST_F(PropagatorTest, PropagateThroughPhis) { } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/reduce_load_size_test.cpp b/3rdparty/spirv-tools/test/opt/reduce_load_size_test.cpp new file mode 100644 index 000000000..1d367e101 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/reduce_load_size_test.cpp @@ -0,0 +1,328 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ReduceLoadSizeTest = PassTest<::testing::Test>; + +#ifdef SPIRV_EFFCEE +TEST_F(ReduceLoadSizeTest, cbuffer_load_extract) { + // Originally from the following HLSL: + // struct S { + // uint f; + // }; + // + // + // cbuffer gBuffer { uint a[32]; }; + // + // RWStructuredBuffer gRWSBuffer; + // + // uint foo(uint p[32]) { + // return p[1]; + // } + // + // [numthreads(1,1,1)] + // void main() { + // gRWSBuffer[0].f = foo(a); + // } + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 600 + OpName %type_gBuffer "type.gBuffer" + OpMemberName %type_gBuffer 0 "a" + OpName %gBuffer "gBuffer" + OpName %S "S" + OpMemberName %S 0 "f" + OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" + OpName %gRWSBuffer "gRWSBuffer" + OpName %main "main" + OpDecorate %_arr_uint_uint_32 ArrayStride 16 + OpMemberDecorate %type_gBuffer 0 Offset 0 + OpDecorate %type_gBuffer Block + OpMemberDecorate %S 0 Offset 0 + OpDecorate %_runtimearr_S ArrayStride 4 + OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 + OpDecorate %type_RWStructuredBuffer_S BufferBlock + OpDecorate %gBuffer DescriptorSet 0 + OpDecorate %gBuffer Binding 0 + OpDecorate %gRWSBuffer DescriptorSet 0 + OpDecorate %gRWSBuffer Binding 1 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 +%_arr_uint_uint_32 = OpTypeArray %uint %uint_32 +%type_gBuffer = OpTypeStruct %_arr_uint_uint_32 +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer + %S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %15 = OpTypeFunction %void + %int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_uint_uint_32 = OpTypePointer Uniform %_arr_uint_uint_32 + %uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint + %gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform + %gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform + %main = OpFunction %void None %15 + %20 = OpLabel +; CHECK: [[ac1:%\w+]] = OpAccessChain {{%\w+}} %gBuffer %int_0 +; CHECK: [[ac2:%\w+]] = OpAccessChain {{%\w+}} [[ac1]] %uint_1 +; CHECK: [[ld:%\w+]] = OpLoad {{%\w+}} [[ac2]] +; CHECK: OpStore {{%\w+}} [[ld]] + %21 = OpAccessChain %_ptr_Uniform__arr_uint_uint_32 %gBuffer %int_0 + %22 = OpLoad %_arr_uint_uint_32 %21 ; Load of 32-element array. + %23 = OpCompositeExtract %uint %22 1 + %24 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 + OpStore %24 %23 + OpReturn + OpFunctionEnd + )"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(test, false); +} +#endif + +TEST_F(ReduceLoadSizeTest, cbuffer_load_extract_vector) { + // Originally from the following HLSL: + // struct S { + // uint f; + // }; + // + // + // cbuffer gBuffer { uint4 a; }; + // + // RWStructuredBuffer gRWSBuffer; + // + // uint foo(uint p[32]) { + // return p[1]; + // } + // + // [numthreads(1,1,1)] + // void main() { + // gRWSBuffer[0].f = foo(a); + // } + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpSource HLSL 600 +OpName %type_gBuffer "type.gBuffer" +OpMemberName %type_gBuffer 0 "a" +OpName %gBuffer "gBuffer" +OpName %S "S" +OpMemberName %S 0 "f" +OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" +OpName %gRWSBuffer "gRWSBuffer" +OpName %main "main" +OpMemberDecorate %type_gBuffer 0 Offset 0 +OpDecorate %type_gBuffer Block +OpMemberDecorate %S 0 Offset 0 +OpDecorate %_runtimearr_S ArrayStride 4 +OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 +OpDecorate %type_RWStructuredBuffer_S BufferBlock +OpDecorate %gBuffer DescriptorSet 0 +OpDecorate %gBuffer Binding 0 +OpDecorate %gRWSBuffer DescriptorSet 0 +OpDecorate %gRWSBuffer Binding 1 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%v4uint = OpTypeVector %uint 4 +%type_gBuffer = OpTypeStruct %v4uint +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer +%S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%int = OpTypeInt 32 1 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%int_0 = OpConstant %int 0 +%_ptr_Uniform_v4uint = OpTypePointer Uniform %v4uint +%uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform +%gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform +%main = OpFunction %void None %15 +%20 = OpLabel +%21 = OpAccessChain %_ptr_Uniform_v4uint %gBuffer %int_0 +%22 = OpLoad %v4uint %21 +%23 = OpCompositeExtract %uint %22 1 +%24 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 +OpStore %24 %23 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndCheck(test, test, true, false); +} + +TEST_F(ReduceLoadSizeTest, cbuffer_load_5_extract) { + // All of the elements of the value loaded are used, so we should not + // change the load. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpSource HLSL 600 +OpName %type_gBuffer "type.gBuffer" +OpMemberName %type_gBuffer 0 "a" +OpName %gBuffer "gBuffer" +OpName %S "S" +OpMemberName %S 0 "f" +OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" +OpName %gRWSBuffer "gRWSBuffer" +OpName %main "main" +OpDecorate %_arr_uint_uint_5 ArrayStride 16 +OpMemberDecorate %type_gBuffer 0 Offset 0 +OpDecorate %type_gBuffer Block +OpMemberDecorate %S 0 Offset 0 +OpDecorate %_runtimearr_S ArrayStride 4 +OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 +OpDecorate %type_RWStructuredBuffer_S BufferBlock +OpDecorate %gBuffer DescriptorSet 0 +OpDecorate %gBuffer Binding 0 +OpDecorate %gRWSBuffer DescriptorSet 0 +OpDecorate %gRWSBuffer Binding 1 +%uint = OpTypeInt 32 0 +%uint_5 = OpConstant %uint 5 +%_arr_uint_uint_5 = OpTypeArray %uint %uint_5 +%type_gBuffer = OpTypeStruct %_arr_uint_uint_5 +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer +%S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%int = OpTypeInt 32 1 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_uint_uint_5 = OpTypePointer Uniform %_arr_uint_uint_5 +%uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform +%gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform +%main = OpFunction %void None %15 +%20 = OpLabel +%21 = OpAccessChain %_ptr_Uniform__arr_uint_uint_5 %gBuffer %int_0 +%22 = OpLoad %_arr_uint_uint_5 %21 +%23 = OpCompositeExtract %uint %22 0 +%24 = OpCompositeExtract %uint %22 1 +%25 = OpCompositeExtract %uint %22 2 +%26 = OpCompositeExtract %uint %22 3 +%27 = OpCompositeExtract %uint %22 4 +%28 = OpIAdd %uint %23 %24 +%29 = OpIAdd %uint %28 %25 +%30 = OpIAdd %uint %29 %26 +%31 = OpIAdd %uint %20 %27 +%32 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 +OpStore %32 %31 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndCheck(test, test, true, false); +} + +TEST_F(ReduceLoadSizeTest, cbuffer_load_fully_used) { + // The result of the load (%22) is used in an instruction that uses the whole + // load and has only 1 in operand. This trigger issue #1559. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpSource HLSL 600 +OpName %type_gBuffer "type.gBuffer" +OpMemberName %type_gBuffer 0 "a" +OpName %gBuffer "gBuffer" +OpName %S "S" +OpMemberName %S 0 "f" +OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" +OpName %gRWSBuffer "gRWSBuffer" +OpName %main "main" +OpMemberDecorate %type_gBuffer 0 Offset 0 +OpDecorate %type_gBuffer Block +OpMemberDecorate %S 0 Offset 0 +OpDecorate %_runtimearr_S ArrayStride 4 +OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 +OpDecorate %type_RWStructuredBuffer_S BufferBlock +OpDecorate %gBuffer DescriptorSet 0 +OpDecorate %gBuffer Binding 0 +OpDecorate %gRWSBuffer DescriptorSet 0 +OpDecorate %gRWSBuffer Binding 1 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%v4uint = OpTypeVector %uint 4 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%type_gBuffer = OpTypeStruct %v4uint +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer +%S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%int = OpTypeInt 32 1 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%int_0 = OpConstant %int 0 +%_ptr_Uniform_v4uint = OpTypePointer Uniform %v4uint +%uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform +%gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform +%main = OpFunction %void None %15 +%20 = OpLabel +%21 = OpAccessChain %_ptr_Uniform_v4uint %gBuffer %int_0 +%22 = OpLoad %v4uint %21 +%23 = OpCompositeExtract %uint %22 1 +%24 = OpConvertUToF %v4float %22 +%25 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 +OpStore %25 %23 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndCheck(test, test, true, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/redundancy_elimination_test.cpp b/3rdparty/spirv-tools/test/opt/redundancy_elimination_test.cpp index c5f386818..a6e8c4f28 100644 --- a/3rdparty/spirv-tools/test/opt/redundancy_elimination_test.cpp +++ b/3rdparty/spirv-tools/test/opt/redundancy_elimination_test.cpp @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/value_number_table.h" +#include -#include "assembly_builder.h" #include "gmock/gmock.h" -#include "opt/build_module.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ::testing::HasSubstr; using ::testing::MatchesRegex; - using RedundancyEliminationTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -55,7 +55,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantLocalAdd) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Remove a redundant add across basic blocks. @@ -84,7 +84,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAdd) { OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Remove a redundant add going through a multiple basic blocks. @@ -120,7 +120,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAddDiamond) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Remove a redundant add in a side node. @@ -156,7 +156,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAddInSideNode) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Remove a redundant add whose value is in the result of a phi node. @@ -196,7 +196,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAddWithPhi) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } // Keep the add because it is redundant on some paths, but not all paths. @@ -230,9 +230,9 @@ TEST_F(RedundancyEliminationTest, KeepPartiallyRedundantAdd) { OpFunctionEnd )"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } // Keep the add. Even if it is redundant on all paths, there is no single id @@ -268,9 +268,13 @@ TEST_F(RedundancyEliminationTest, KeepRedundantAddWithoutPhi) { OpFunctionEnd )"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } + #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/register_liveness.cpp b/3rdparty/spirv-tools/test/opt/register_liveness.cpp new file mode 100644 index 000000000..cb973d2e6 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/register_liveness.cpp @@ -0,0 +1,1282 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/register_pressure.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +void CompareSets(const std::unordered_set& computed, + const std::unordered_set& expected) { + for (Instruction* insn : computed) { + EXPECT_TRUE(expected.count(insn->result_id())) + << "Unexpected instruction in live set: " << *insn; + } + EXPECT_EQ(computed.size(), expected.size()); +} + +/* +Generated from the following GLSL + +#version 330 +in vec4 BaseColor; +flat in int Count; +void main() +{ + vec4 color = BaseColor; + vec4 acc; + if (Count == 0) { + acc = color; + } + else { + acc = color + vec4(0,1,2,0); + } + gl_FragColor = acc + color; +} +*/ +TEST_F(PassClassTest, LivenessWithIf) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %11 %15 %32 + OpExecutionMode %4 OriginLowerLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %11 "BaseColor" + OpName %15 "Count" + OpName %32 "gl_FragColor" + OpDecorate %11 Location 0 + OpDecorate %15 Flat + OpDecorate %15 Location 0 + OpDecorate %32 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %10 = OpTypePointer Input %7 + %11 = OpVariable %10 Input + %13 = OpTypeInt 32 1 + %14 = OpTypePointer Input %13 + %15 = OpVariable %14 Input + %17 = OpConstant %13 0 + %18 = OpTypeBool + %26 = OpConstant %6 0 + %27 = OpConstant %6 1 + %28 = OpConstant %6 2 + %29 = OpConstantComposite %7 %26 %27 %28 %26 + %31 = OpTypePointer Output %7 + %32 = OpVariable %31 Output + %4 = OpFunction %2 None %3 + %5 = OpLabel + %12 = OpLoad %7 %11 + %16 = OpLoad %13 %15 + %19 = OpIEqual %18 %16 %17 + OpSelectionMerge %21 None + OpBranchConditional %19 %20 %24 + %20 = OpLabel + OpBranch %21 + %24 = OpLabel + %30 = OpFAdd %7 %12 %29 + OpBranch %21 + %21 = OpLabel + %36 = OpPhi %7 %12 %20 %30 %24 + %35 = OpFAdd %7 %36 %12 + OpStore %32 %35 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function* f = &*module->begin(); + LivenessAnalysis* liveness_analysis = context->GetLivenessAnalysis(); + const RegisterLiveness* register_liveness = liveness_analysis->Get(f); + { + SCOPED_TRACE("Block 5"); + auto live_sets = register_liveness->Get(5); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 15, // %15 = OpVariable %14 Input + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_out_, live_out); + } + { + SCOPED_TRACE("Block 20"); + auto live_sets = register_liveness->Get(20); + std::unordered_set live_inout{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + } + { + SCOPED_TRACE("Block 24"); + auto live_sets = register_liveness->Get(24); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 30, // %30 = OpFAdd %7 %12 %29 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_out_, live_out); + } + { + SCOPED_TRACE("Block 21"); + auto live_sets = register_liveness->Get(21); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + 36, // %36 = OpPhi %7 %12 %20 %30 %24 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{}; + CompareSets(live_sets->live_out_, live_out); + } +} + +/* +Generated from the following GLSL +#version 330 +in vec4 bigColor; +in vec4 BaseColor; +in float f; +flat in int Count; +flat in uvec4 v4; +void main() +{ + vec4 color = BaseColor; + for (int i = 0; i < Count; ++i) + color += bigColor; + float sum = 0.0; + for (int i = 0; i < 4; ++i) { + float acc = 0.0; + if (sum == 0.0) { + acc = v4[i]; + } + else { + acc = BaseColor[i]; + } + sum += acc + v4[i]; + } + vec4 tv4; + for (int i = 0; i < 4; ++i) + tv4[i] = v4[i] * 4u; + color += vec4(sum) + tv4; + vec4 r; + r.xyz = BaseColor.xyz; + for (int i = 0; i < Count; ++i) + r.w = f; + color.xyz += r.xyz; + for (int i = 0; i < 16; i += 4) + for (int j = 0; j < 4; j++) + color *= f; + gl_FragColor = color + tv4; +} +*/ +TEST_F(PassClassTest, RegisterLiveness) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %11 %24 %28 %55 %124 %176 + OpExecutionMode %4 OriginLowerLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %11 "BaseColor" + OpName %24 "Count" + OpName %28 "bigColor" + OpName %55 "v4" + OpName %84 "tv4" + OpName %124 "f" + OpName %176 "gl_FragColor" + OpDecorate %11 Location 0 + OpDecorate %24 Flat + OpDecorate %24 Location 0 + OpDecorate %28 Location 0 + OpDecorate %55 Flat + OpDecorate %55 Location 0 + OpDecorate %124 Location 0 + OpDecorate %176 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Function %7 + %10 = OpTypePointer Input %7 + %11 = OpVariable %10 Input + %13 = OpTypeInt 32 1 + %16 = OpConstant %13 0 + %23 = OpTypePointer Input %13 + %24 = OpVariable %23 Input + %26 = OpTypeBool + %28 = OpVariable %10 Input + %33 = OpConstant %13 1 + %35 = OpTypePointer Function %6 + %37 = OpConstant %6 0 + %45 = OpConstant %13 4 + %52 = OpTypeInt 32 0 + %53 = OpTypeVector %52 4 + %54 = OpTypePointer Input %53 + %55 = OpVariable %54 Input + %57 = OpTypePointer Input %52 + %63 = OpTypePointer Input %6 + %89 = OpConstant %52 4 + %102 = OpTypeVector %6 3 + %124 = OpVariable %63 Input + %158 = OpConstant %13 16 + %175 = OpTypePointer Output %7 + %176 = OpVariable %175 Output + %195 = OpUndef %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %84 = OpVariable %8 Function + %12 = OpLoad %7 %11 + OpBranch %17 + %17 = OpLabel + %191 = OpPhi %7 %12 %5 %31 %18 + %184 = OpPhi %13 %16 %5 %34 %18 + %25 = OpLoad %13 %24 + %27 = OpSLessThan %26 %184 %25 + OpLoopMerge %19 %18 None + OpBranchConditional %27 %18 %19 + %18 = OpLabel + %29 = OpLoad %7 %28 + %31 = OpFAdd %7 %191 %29 + %34 = OpIAdd %13 %184 %33 + OpBranch %17 + %19 = OpLabel + OpBranch %39 + %39 = OpLabel + %188 = OpPhi %6 %37 %19 %73 %51 + %185 = OpPhi %13 %16 %19 %75 %51 + %46 = OpSLessThan %26 %185 %45 + OpLoopMerge %41 %51 None + OpBranchConditional %46 %40 %41 + %40 = OpLabel + %49 = OpFOrdEqual %26 %188 %37 + OpSelectionMerge %51 None + OpBranchConditional %49 %50 %61 + %50 = OpLabel + %58 = OpAccessChain %57 %55 %185 + %59 = OpLoad %52 %58 + %60 = OpConvertUToF %6 %59 + OpBranch %51 + %61 = OpLabel + %64 = OpAccessChain %63 %11 %185 + %65 = OpLoad %6 %64 + OpBranch %51 + %51 = OpLabel + %210 = OpPhi %6 %60 %50 %65 %61 + %68 = OpAccessChain %57 %55 %185 + %69 = OpLoad %52 %68 + %70 = OpConvertUToF %6 %69 + %71 = OpFAdd %6 %210 %70 + %73 = OpFAdd %6 %188 %71 + %75 = OpIAdd %13 %185 %33 + OpBranch %39 + %41 = OpLabel + OpBranch %77 + %77 = OpLabel + %186 = OpPhi %13 %16 %41 %94 %78 + %83 = OpSLessThan %26 %186 %45 + OpLoopMerge %79 %78 None + OpBranchConditional %83 %78 %79 + %78 = OpLabel + %87 = OpAccessChain %57 %55 %186 + %88 = OpLoad %52 %87 + %90 = OpIMul %52 %88 %89 + %91 = OpConvertUToF %6 %90 + %92 = OpAccessChain %35 %84 %186 + OpStore %92 %91 + %94 = OpIAdd %13 %186 %33 + OpBranch %77 + %79 = OpLabel + %96 = OpCompositeConstruct %7 %188 %188 %188 %188 + %97 = OpLoad %7 %84 + %98 = OpFAdd %7 %96 %97 + %100 = OpFAdd %7 %191 %98 + %104 = OpVectorShuffle %102 %12 %12 0 1 2 + %106 = OpVectorShuffle %7 %195 %104 4 5 6 3 + OpBranch %108 + %108 = OpLabel + %197 = OpPhi %7 %106 %79 %208 %133 + %196 = OpPhi %13 %16 %79 %143 %133 + %115 = OpSLessThan %26 %196 %25 + OpLoopMerge %110 %133 None + OpBranchConditional %115 %109 %110 + %109 = OpLabel + OpBranch %117 + %117 = OpLabel + %209 = OpPhi %7 %197 %109 %181 %118 + %204 = OpPhi %13 %16 %109 %129 %118 + %123 = OpSLessThan %26 %204 %45 + OpLoopMerge %119 %118 None + OpBranchConditional %123 %118 %119 + %118 = OpLabel + %125 = OpLoad %6 %124 + %181 = OpCompositeInsert %7 %125 %209 3 + %129 = OpIAdd %13 %204 %33 + OpBranch %117 + %119 = OpLabel + OpBranch %131 + %131 = OpLabel + %208 = OpPhi %7 %209 %119 %183 %132 + %205 = OpPhi %13 %16 %119 %141 %132 + %137 = OpSLessThan %26 %205 %45 + OpLoopMerge %133 %132 None + OpBranchConditional %137 %132 %133 + %132 = OpLabel + %138 = OpLoad %6 %124 + %183 = OpCompositeInsert %7 %138 %208 3 + %141 = OpIAdd %13 %205 %33 + OpBranch %131 + %133 = OpLabel + %143 = OpIAdd %13 %196 %33 + OpBranch %108 + %110 = OpLabel + %145 = OpVectorShuffle %102 %197 %197 0 1 2 + %147 = OpVectorShuffle %102 %100 %100 0 1 2 + %148 = OpFAdd %102 %147 %145 + %150 = OpVectorShuffle %7 %100 %148 4 5 6 3 + OpBranch %152 + %152 = OpLabel + %200 = OpPhi %7 %150 %110 %203 %163 + %199 = OpPhi %13 %16 %110 %174 %163 + %159 = OpSLessThan %26 %199 %158 + OpLoopMerge %154 %163 None + OpBranchConditional %159 %153 %154 + %153 = OpLabel + OpBranch %161 + %161 = OpLabel + %203 = OpPhi %7 %200 %153 %170 %162 + %201 = OpPhi %13 %16 %153 %172 %162 + %167 = OpSLessThan %26 %201 %45 + OpLoopMerge %163 %162 None + OpBranchConditional %167 %162 %163 + %162 = OpLabel + %168 = OpLoad %6 %124 + %170 = OpVectorTimesScalar %7 %203 %168 + %172 = OpIAdd %13 %201 %33 + OpBranch %161 + %163 = OpLabel + %174 = OpIAdd %13 %199 %45 + OpBranch %152 + %154 = OpLabel + %178 = OpLoad %7 %84 + %179 = OpFAdd %7 %200 %178 + OpStore %176 %179 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function* f = &*module->begin(); + LivenessAnalysis* liveness_analysis = context->GetLivenessAnalysis(); + const RegisterLiveness* register_liveness = liveness_analysis->Get(f); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + { + SCOPED_TRACE("Block 5"); + auto live_sets = register_liveness->Get(5); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 17"); + auto live_sets = register_liveness->Get(17); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 18"); + auto live_sets = register_liveness->Get(18); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 31, // %31 = OpFAdd %7 %191 %29 + 34, // %34 = OpIAdd %13 %184 %33 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 12u); + } + { + SCOPED_TRACE("Block 19"); + auto live_sets = register_liveness->Get(19); + std::unordered_set live_inout{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 39"); + auto live_sets = register_liveness->Get(39); + std::unordered_set live_inout{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 40"); + auto live_sets = register_liveness->Get(40); + std::unordered_set live_inout{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 50"); + auto live_sets = register_liveness->Get(50); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 60, // %60 = OpConvertUToF %6 %59 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 12u); + } + { + SCOPED_TRACE("Block 61"); + auto live_sets = register_liveness->Get(61); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 65, // %65 = OpLoad %6 %64 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 12u); + } + { + SCOPED_TRACE("Block 51"); + auto live_sets = register_liveness->Get(51); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + 210, // %210 = OpPhi %6 %60 %50 %65 %61 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 73, // %73 = OpFAdd %6 %188 %71 + 75, // %75 = OpIAdd %13 %185 %33 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 13u); + } + { + SCOPED_TRACE("Block 41"); + auto live_sets = register_liveness->Get(41); + std::unordered_set live_inout{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 77"); + auto live_sets = register_liveness->Get(77); + std::unordered_set live_inout{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 186, // %186 = OpPhi %13 %16 %41 %94 %78 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 10u); + } + { + SCOPED_TRACE("Block 78"); + auto live_sets = register_liveness->Get(78); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 186, // %186 = OpPhi %13 %16 %41 %94 %78 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 94, // %94 = OpIAdd %13 %186 %33 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 79"); + auto live_sets = register_liveness->Get(79); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 106, // %106 = OpVectorShuffle %7 %195 %104 4 5 6 3 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 9u); + } + { + SCOPED_TRACE("Block 108"); + auto live_sets = register_liveness->Get(108); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 109"); + auto live_sets = register_liveness->Get(109); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 117"); + auto live_sets = register_liveness->Get(117); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 204, // %204 = OpPhi %13 %16 %109 %129 %118 + 209, // %209 = OpPhi %7 %197 %109 %181 %118 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 9u); + } + { + SCOPED_TRACE("Block 118"); + auto live_sets = register_liveness->Get(118); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 204, // %204 = OpPhi %13 %16 %109 %129 %118 + 209, // %209 = OpPhi %7 %197 %109 %181 %118 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 129, // %129 = OpIAdd %13 %204 %33 + 176, // %176 = OpVariable %175 Output + 181, // %181 = OpCompositeInsert %7 %125 %209 3 + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 10u); + } + { + SCOPED_TRACE("Block 119"); + auto live_sets = register_liveness->Get(119); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 209, // %209 = OpPhi %7 %197 %109 %181 %118 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 131"); + auto live_sets = register_liveness->Get(131); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 205, // %205 = OpPhi %13 %16 %119 %141 %132 + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 9u); + } + { + SCOPED_TRACE("Block 132"); + auto live_sets = register_liveness->Get(132); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 205, // %205 = OpPhi %13 %16 %119 %141 %132 + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 141, // %141 = OpIAdd %13 %205 %33 + 176, // %176 = OpVariable %175 Output + 183, // %183 = OpCompositeInsert %7 %138 %208 3 + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 10u); + } + { + SCOPED_TRACE("Block 133"); + auto live_sets = register_liveness->Get(133); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 143, // %143 = OpIAdd %13 %196 %33 + 176, // %176 = OpVariable %175 Output + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 110"); + auto live_sets = register_liveness->Get(110); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 150, // %150 = OpVectorShuffle %7 %100 %148 4 5 6 3 + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 152"); + auto live_sets = register_liveness->Get(152); + std::unordered_set live_inout{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 200, // %200 = OpPhi %7 %150 %110 %203 %163 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 6u); + } + { + SCOPED_TRACE("Block 153"); + auto live_sets = register_liveness->Get(153); + std::unordered_set live_inout{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 200, // %200 = OpPhi %7 %150 %110 %203 %163 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 5u); + } + { + SCOPED_TRACE("Block 161"); + auto live_sets = register_liveness->Get(161); + std::unordered_set live_inout{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 201, // %201 = OpPhi %13 %16 %153 %172 %162 + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 162"); + auto live_sets = register_liveness->Get(162); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 201, // %201 = OpPhi %13 %16 %153 %172 %162 + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 170, // %170 = OpVectorTimesScalar %7 %203 %168 + 172, // %172 = OpIAdd %13 %201 %33 + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 163"); + auto live_sets = register_liveness->Get(163); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 174, // %174 = OpIAdd %13 %199 %45 + 176, // %176 = OpVariable %175 Output + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 6u); + } + { + SCOPED_TRACE("Block 154"); + auto live_sets = register_liveness->Get(154); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 176, // %176 = OpVariable %175 Output + 200, // %200 = OpPhi %7 %150 %110 %203 %163 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{}; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 4u); + } + + { + SCOPED_TRACE("Compute loop pressure"); + RegisterLiveness::RegionRegisterLiveness loop_reg_pressure; + register_liveness->ComputeLoopRegisterPressure(*ld[39], &loop_reg_pressure); + // Generate(*context->cfg()->block(39), &loop_reg_pressure); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(loop_reg_pressure.live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(loop_reg_pressure.live_out_, live_out); + + EXPECT_EQ(loop_reg_pressure.used_registers_, 13u); + } + + { + SCOPED_TRACE("Loop Fusion simulation"); + RegisterLiveness::RegionRegisterLiveness simulation_resut; + register_liveness->SimulateFusion(*ld[17], *ld[39], &simulation_resut); + + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 25, // %25 = OpLoad %13 %24 + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(simulation_resut.live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(simulation_resut.live_out_, live_out); + + EXPECT_EQ(simulation_resut.used_registers_, 17u); + } +} + +TEST_F(PassClassTest, FissionSimulation) { + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "A" + OpName %5 "B" + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpTypePointer Function %13 + %19 = OpConstant %8 1 + %2 = OpFunction %6 None %7 + %20 = OpLabel + %3 = OpVariable %9 Function + %4 = OpVariable %17 Function + %5 = OpVariable %17 Function + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %8 %10 %20 %23 %24 + OpLoopMerge %25 %24 None + OpBranch %26 + %26 = OpLabel + %27 = OpSLessThan %12 %22 %11 + OpBranchConditional %27 %28 %25 + %28 = OpLabel + %29 = OpAccessChain %18 %5 %22 + %30 = OpLoad %13 %29 + %31 = OpAccessChain %18 %4 %22 + OpStore %31 %30 + %32 = OpAccessChain %18 %4 %22 + %33 = OpLoad %13 %32 + %34 = OpAccessChain %18 %5 %22 + OpStore %34 %33 + OpBranch %24 + %24 = OpLabel + %23 = OpIAdd %8 %22 %19 + OpBranch %21 + %25 = OpLabel + OpStore %3 %22 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + Function* f = &*module->begin(); + LivenessAnalysis* liveness_analysis = context->GetLivenessAnalysis(); + const RegisterLiveness* register_liveness = liveness_analysis->Get(f); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + analysis::DefUseManager& def_use_mgr = *context->get_def_use_mgr(); + + { + RegisterLiveness::RegionRegisterLiveness l1_sim_resut; + RegisterLiveness::RegionRegisterLiveness l2_sim_resut; + std::unordered_set moved_instructions{ + def_use_mgr.GetDef(29), def_use_mgr.GetDef(30), def_use_mgr.GetDef(31), + def_use_mgr.GetDef(31)->NextNode()}; + std::unordered_set copied_instructions{ + def_use_mgr.GetDef(22), def_use_mgr.GetDef(27), + def_use_mgr.GetDef(27)->NextNode(), def_use_mgr.GetDef(23)}; + + register_liveness->SimulateFission(*ld[21], moved_instructions, + copied_instructions, &l1_sim_resut, + &l2_sim_resut); + { + SCOPED_TRACE("L1 simulation"); + std::unordered_set live_in{ + 3, // %3 = OpVariable %9 Function + 4, // %4 = OpVariable %17 Function + 5, // %5 = OpVariable %17 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l1_sim_resut.live_in_, live_in); + + std::unordered_set live_out{ + 3, // %3 = OpVariable %9 Function + 4, // %4 = OpVariable %17 Function + 5, // %5 = OpVariable %17 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l1_sim_resut.live_out_, live_out); + + EXPECT_EQ(l1_sim_resut.used_registers_, 6u); + } + { + SCOPED_TRACE("L2 simulation"); + std::unordered_set live_in{ + 3, // %3 = OpVariable %9 Function + 4, // %4 = OpVariable %17 Function + 5, // %5 = OpVariable %17 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l2_sim_resut.live_in_, live_in); + + std::unordered_set live_out{ + 3, // %3 = OpVariable %9 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l2_sim_resut.live_out_, live_out); + + EXPECT_EQ(l2_sim_resut.used_registers_, 6u); + } + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/replace_invalid_opc_test.cpp b/3rdparty/spirv-tools/test/opt/replace_invalid_opc_test.cpp index badff544f..adfe2ee1e 100644 --- a/3rdparty/spirv-tools/test/opt/replace_invalid_opc_test.cpp +++ b/3rdparty/spirv-tools/test/opt/replace_invalid_opc_test.cpp @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" -#include "gmock/gmock.h" -#include "pass_fixture.h" - #include +#include +#include +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ReplaceInvalidOpcodeTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -77,7 +79,7 @@ TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstruction) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstructionInNonEntryPoint) { @@ -137,7 +139,7 @@ TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstructionInNonEntryPoint) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstructionMultipleEntryPoints) { @@ -206,7 +208,7 @@ TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstructionMultipleEntryPoints) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(ReplaceInvalidOpcodeTest, DontReplaceInstruction) { const std::string text = R"( @@ -256,9 +258,9 @@ TEST_F(ReplaceInvalidOpcodeTest, DontReplaceInstruction) { OpReturn OpFunctionEnd)"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(ReplaceInvalidOpcodeTest, MultipleEntryPointsDifferentStage) { @@ -321,9 +323,9 @@ TEST_F(ReplaceInvalidOpcodeTest, MultipleEntryPointsDifferentStage) { OpReturn OpFunctionEnd)"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(ReplaceInvalidOpcodeTest, DontReplaceLinkage) { @@ -375,9 +377,9 @@ TEST_F(ReplaceInvalidOpcodeTest, DontReplaceLinkage) { OpReturn OpFunctionEnd)"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(ReplaceInvalidOpcodeTest, BarrierDontReplace) { @@ -402,9 +404,9 @@ TEST_F(ReplaceInvalidOpcodeTest, BarrierDontReplace) { OpReturn OpFunctionEnd)"; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } TEST_F(ReplaceInvalidOpcodeTest, BarrierReplace) { @@ -430,7 +432,7 @@ TEST_F(ReplaceInvalidOpcodeTest, BarrierReplace) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } struct Message { @@ -516,9 +518,9 @@ TEST_F(ReplaceInvalidOpcodeTest, MessageTest) { "Removing ImageSampleImplicitLod instruction because of incompatible " "execution model."}}; SetMessageConsumer(GetTestMessageConsumer(messages)); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); } TEST_F(ReplaceInvalidOpcodeTest, MultipleMessageTest) { @@ -582,9 +584,13 @@ TEST_F(ReplaceInvalidOpcodeTest, MultipleMessageTest) { "incompatible " "execution model."}}; SetMessageConsumer(GetTestMessageConsumer(messages)); - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); } + #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/scalar_analysis.cpp b/3rdparty/spirv-tools/test/opt/scalar_analysis.cpp index a73953eb8..598d8c7b7 100644 --- a/3rdparty/spirv-tools/test/opt/scalar_analysis.cpp +++ b/3rdparty/spirv-tools/test/opt/scalar_analysis.cpp @@ -12,29 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include #include -#include "assembly_builder.h" -#include "function_utils.h" -#include "pass_fixture.h" -#include "pass_utils.h" - -#include "opt/iterator.h" -#include "opt/loop_descriptor.h" -#include "opt/pass.h" -#include "opt/scalar_analysis.h" -#include "opt/tree_iterator.h" +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using ::testing::UnorderedElementsAre; - using ScalarAnalysisTest = PassTest<::testing::Test>; /* @@ -99,18 +97,18 @@ TEST_F(ScalarAnalysisTest, BasicEvolutionTest) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 4); + ScalarEvolutionAnalysis analysis{context.get()}; - const ir::Instruction* store = nullptr; - const ir::Instruction* load = nullptr; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 11)) { + const Instruction* store = nullptr; + const Instruction* load = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) { if (inst.opcode() == SpvOp::SpvOpStore) { store = &inst; } @@ -122,36 +120,35 @@ TEST_F(ScalarAnalysisTest, BasicEvolutionTest) { EXPECT_NE(load, nullptr); EXPECT_NE(store, nullptr); - ir::Instruction* access_chain = + Instruction* access_chain = context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0)); - ir::Instruction* child = context->get_def_use_mgr()->GetDef( + Instruction* child = context->get_def_use_mgr()->GetDef( access_chain->GetSingleWordInOperand(1)); - const opt::SENode* node = analysis.AnalyzeInstruction(child); + const SENode* node = analysis.AnalyzeInstruction(child); EXPECT_NE(node, nullptr); // Unsimplified node should have the form of ADD(REC(0,1), 1) - EXPECT_EQ(node->GetType(), opt::SENode::Add); + EXPECT_EQ(node->GetType(), SENode::Add); - const opt::SENode* child_1 = node->GetChild(0); - EXPECT_TRUE(child_1->GetType() == opt::SENode::Constant || - child_1->GetType() == opt::SENode::RecurrentAddExpr); + const SENode* child_1 = node->GetChild(0); + EXPECT_TRUE(child_1->GetType() == SENode::Constant || + child_1->GetType() == SENode::RecurrentAddExpr); - const opt::SENode* child_2 = node->GetChild(1); - EXPECT_TRUE(child_2->GetType() == opt::SENode::Constant || - child_2->GetType() == opt::SENode::RecurrentAddExpr); + const SENode* child_2 = node->GetChild(1); + EXPECT_TRUE(child_2->GetType() == SENode::Constant || + child_2->GetType() == SENode::RecurrentAddExpr); - opt::SENode* simplified = - analysis.SimplifyExpression(const_cast(node)); + SENode* simplified = analysis.SimplifyExpression(const_cast(node)); // Simplified should be in the form of REC(1,1) - EXPECT_EQ(simplified->GetType(), opt::SENode::RecurrentAddExpr); + EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr); - EXPECT_EQ(simplified->GetChild(0)->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant); EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(), 1); - EXPECT_EQ(simplified->GetChild(1)->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant); EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(), 1); @@ -228,17 +225,17 @@ TEST_F(ScalarAnalysisTest, LoadTest) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; - const ir::Instruction* load = nullptr; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 28)) { + const Instruction* load = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 28)) { if (inst.opcode() == SpvOp::SpvOpLoad) { load = &inst; } @@ -246,40 +243,39 @@ TEST_F(ScalarAnalysisTest, LoadTest) { EXPECT_NE(load, nullptr); - ir::Instruction* access_chain = + Instruction* access_chain = context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0)); - ir::Instruction* child = context->get_def_use_mgr()->GetDef( + Instruction* child = context->get_def_use_mgr()->GetDef( access_chain->GetSingleWordInOperand(1)); - // const opt::SENode* node = + // const SENode* node = // analysis.GetNodeFromInstruction(child->unique_id()); - const opt::SENode* node = analysis.AnalyzeInstruction(child); + const SENode* node = analysis.AnalyzeInstruction(child); EXPECT_NE(node, nullptr); // Unsimplified node should have the form of ADD(REC(0,1), X) - EXPECT_EQ(node->GetType(), opt::SENode::Add); + EXPECT_EQ(node->GetType(), SENode::Add); - const opt::SENode* child_1 = node->GetChild(0); - EXPECT_TRUE(child_1->GetType() == opt::SENode::ValueUnknown || - child_1->GetType() == opt::SENode::RecurrentAddExpr); + const SENode* child_1 = node->GetChild(0); + EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown || + child_1->GetType() == SENode::RecurrentAddExpr); - const opt::SENode* child_2 = node->GetChild(1); - EXPECT_TRUE(child_2->GetType() == opt::SENode::ValueUnknown || - child_2->GetType() == opt::SENode::RecurrentAddExpr); + const SENode* child_2 = node->GetChild(1); + EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown || + child_2->GetType() == SENode::RecurrentAddExpr); - opt::SENode* simplified = - analysis.SimplifyExpression(const_cast(node)); - EXPECT_EQ(simplified->GetType(), opt::SENode::RecurrentAddExpr); + SENode* simplified = analysis.SimplifyExpression(const_cast(node)); + EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr); - const opt::SERecurrentNode* rec = simplified->AsSERecurrentNode(); + const SERecurrentNode* rec = simplified->AsSERecurrentNode(); EXPECT_NE(rec->GetChild(0), rec->GetChild(1)); - EXPECT_EQ(rec->GetOffset()->GetType(), opt::SENode::ValueUnknown); + EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown); - EXPECT_EQ(rec->GetCoefficient()->GetType(), opt::SENode::Constant); + EXPECT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant); EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u); } @@ -345,17 +341,17 @@ TEST_F(ScalarAnalysisTest, SimplifySimple) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; - const ir::Instruction* load = nullptr; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 21)) { + const Instruction* load = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) { if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) { load = &inst; } @@ -363,24 +359,23 @@ TEST_F(ScalarAnalysisTest, SimplifySimple) { EXPECT_NE(load, nullptr); - ir::Instruction* access_chain = + Instruction* access_chain = context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0)); - ir::Instruction* child = context->get_def_use_mgr()->GetDef( + Instruction* child = context->get_def_use_mgr()->GetDef( access_chain->GetSingleWordInOperand(1)); - const opt::SENode* node = analysis.AnalyzeInstruction(child); + const SENode* node = analysis.AnalyzeInstruction(child); // Unsimplified is a very large graph with an add at the top. EXPECT_NE(node, nullptr); - EXPECT_EQ(node->GetType(), opt::SENode::Add); + EXPECT_EQ(node->GetType(), SENode::Add); // Simplified node should resolve down to a constant expression as the loads // will eliminate themselves. - opt::SENode* simplified = - analysis.SimplifyExpression(const_cast(node)); + SENode* simplified = analysis.SimplifyExpression(const_cast(node)); - EXPECT_EQ(simplified->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified->GetType(), SENode::Constant); EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u); } @@ -496,21 +491,21 @@ TEST_F(ScalarAnalysisTest, Simplify) { OpFunctionEnd )"; // clang-format on - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 4); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 4); + ScalarEvolutionAnalysis analysis{context.get()}; - const ir::Instruction* loads[6]; - const ir::Instruction* stores[6]; + const Instruction* loads[6]; + const Instruction* stores[6]; int load_count = 0; int store_count = 0; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 22)) { + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) { if (inst.opcode() == SpvOp::SpvOpLoad) { loads[load_count] = &inst; ++load_count; @@ -524,14 +519,14 @@ TEST_F(ScalarAnalysisTest, Simplify) { EXPECT_EQ(load_count, 6); EXPECT_EQ(store_count, 6); - ir::Instruction* load_access_chain; - ir::Instruction* store_access_chain; - ir::Instruction* load_child; - ir::Instruction* store_child; - opt::SENode* load_node; - opt::SENode* store_node; - opt::SENode* subtract_node; - opt::SENode* simplified_node; + Instruction* load_access_chain; + Instruction* store_access_chain; + Instruction* load_child; + Instruction* store_child; + SENode* load_node; + SENode* store_node; + SENode* subtract_node; + SENode* simplified_node; // Testing [i] - [i] == 0 load_access_chain = @@ -549,7 +544,7 @@ TEST_F(ScalarAnalysisTest, Simplify) { subtract_node = analysis.CreateSubtraction(store_node, load_node); simplified_node = analysis.SimplifyExpression(subtract_node); - EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u); // Testing [i] - [i-1] == 1 @@ -569,7 +564,7 @@ TEST_F(ScalarAnalysisTest, Simplify) { subtract_node = analysis.CreateSubtraction(store_node, load_node); simplified_node = analysis.SimplifyExpression(subtract_node); - EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u); // Testing [i] - [i+1] == -1 @@ -588,7 +583,7 @@ TEST_F(ScalarAnalysisTest, Simplify) { subtract_node = analysis.CreateSubtraction(store_node, load_node); simplified_node = analysis.SimplifyExpression(subtract_node); - EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1); // Testing [i+1] - [i+1] == 0 @@ -607,7 +602,7 @@ TEST_F(ScalarAnalysisTest, Simplify) { subtract_node = analysis.CreateSubtraction(store_node, load_node); simplified_node = analysis.SimplifyExpression(subtract_node); - EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u); // Testing [i+N] - [i+N] == 0 @@ -627,7 +622,7 @@ TEST_F(ScalarAnalysisTest, Simplify) { subtract_node = analysis.CreateSubtraction(store_node, load_node); simplified_node = analysis.SimplifyExpression(subtract_node); - EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u); // Testing [i] - [i+N] == -N @@ -646,7 +641,7 @@ TEST_F(ScalarAnalysisTest, Simplify) { subtract_node = analysis.CreateSubtraction(store_node, load_node); simplified_node = analysis.SimplifyExpression(subtract_node); - EXPECT_EQ(simplified_node->GetType(), opt::SENode::Negative); + EXPECT_EQ(simplified_node->GetType(), SENode::Negative); } /* @@ -735,21 +730,21 @@ TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; - const ir::Instruction* loads[2] = {nullptr, nullptr}; - const ir::Instruction* stores[2] = {nullptr, nullptr}; + const Instruction* loads[2] = {nullptr, nullptr}; + const Instruction* stores[2] = {nullptr, nullptr}; int load_count = 0; int store_count = 0; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 31)) { + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) { if (inst.opcode() == SpvOp::SpvOpLoad) { loads[load_count] = &inst; ++load_count; @@ -763,19 +758,19 @@ TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) { EXPECT_EQ(load_count, 2); EXPECT_EQ(store_count, 2); - ir::Instruction* load_access_chain = + Instruction* load_access_chain = context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0)); - ir::Instruction* store_access_chain = + Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0)); - ir::Instruction* load_child = context->get_def_use_mgr()->GetDef( + Instruction* load_child = context->get_def_use_mgr()->GetDef( load_access_chain->GetSingleWordInOperand(1)); - ir::Instruction* store_child = context->get_def_use_mgr()->GetDef( + Instruction* store_child = context->get_def_use_mgr()->GetDef( store_access_chain->GetSingleWordInOperand(1)); - opt::SENode* store_node = analysis.AnalyzeInstruction(store_child); + SENode* store_node = analysis.AnalyzeInstruction(store_child); - opt::SENode* store_simplified = analysis.SimplifyExpression(store_node); + SENode* store_simplified = analysis.SimplifyExpression(store_node); load_access_chain = context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0)); @@ -786,11 +781,11 @@ TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) { store_child = context->get_def_use_mgr()->GetDef( store_access_chain->GetSingleWordInOperand(1)); - opt::SENode* second_store = + SENode* second_store = analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child)); - opt::SENode* second_load = + SENode* second_load = analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child)); - opt::SENode* combined_add = analysis.SimplifyExpression( + SENode* combined_add = analysis.SimplifyExpression( analysis.CreateAddNode(second_load, second_store)); // We're checking that the two recurrent expression have been correctly @@ -872,19 +867,19 @@ TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; - const ir::Instruction* loads[1] = {nullptr}; + const Instruction* loads[1] = {nullptr}; int load_count = 0; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { if (inst.opcode() == SpvOp::SpvOpLoad) { loads[load_count] = &inst; ++load_count; @@ -893,38 +888,38 @@ TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) { EXPECT_EQ(load_count, 1); - ir::Instruction* load_access_chain = + Instruction* load_access_chain = context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0)); - ir::Instruction* load_child = context->get_def_use_mgr()->GetDef( + Instruction* load_child = context->get_def_use_mgr()->GetDef( load_access_chain->GetSingleWordInOperand(1)); - opt::SENode* load_node = analysis.AnalyzeInstruction(load_child); + SENode* load_node = analysis.AnalyzeInstruction(load_child); EXPECT_TRUE(load_node); - EXPECT_EQ(load_node->GetType(), opt::SENode::RecurrentAddExpr); + EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr); EXPECT_TRUE(load_node->AsSERecurrentNode()); - opt::SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient(); - opt::SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset(); + SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient(); + SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset(); - EXPECT_EQ(child_1->GetType(), opt::SENode::Constant); - EXPECT_EQ(child_2->GetType(), opt::SENode::Constant); + EXPECT_EQ(child_1->GetType(), SENode::Constant); + EXPECT_EQ(child_2->GetType(), SENode::Constant); EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1); EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u); - opt::SERecurrentNode* load_simplified = + SERecurrentNode* load_simplified = analysis.SimplifyExpression(load_node)->AsSERecurrentNode(); EXPECT_TRUE(load_simplified); EXPECT_EQ(load_node, load_simplified); - EXPECT_EQ(load_simplified->GetType(), opt::SENode::RecurrentAddExpr); + EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr); EXPECT_TRUE(load_simplified->AsSERecurrentNode()); - opt::SENode* simplified_child_1 = + SENode* simplified_child_1 = load_simplified->AsSERecurrentNode()->GetCoefficient(); - opt::SENode* simplified_child_2 = + SENode* simplified_child_2 = load_simplified->AsSERecurrentNode()->GetOffset(); EXPECT_EQ(child_1, simplified_child_1); @@ -1017,19 +1012,19 @@ TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; - std::vector loads{}; - std::vector stores{}; + std::vector loads{}; + std::vector stores{}; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 30)) { + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) { if (inst.opcode() == SpvOp::SpvOpLoad) { loads.push_back(&inst); } @@ -1041,81 +1036,77 @@ TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) { EXPECT_EQ(loads.size(), 3u); EXPECT_EQ(stores.size(), 2u); { - ir::Instruction* store_access_chain = context->get_def_use_mgr()->GetDef( + Instruction* store_access_chain = context->get_def_use_mgr()->GetDef( stores[0]->GetSingleWordInOperand(0)); - ir::Instruction* store_child = context->get_def_use_mgr()->GetDef( + Instruction* store_child = context->get_def_use_mgr()->GetDef( store_access_chain->GetSingleWordInOperand(1)); - opt::SENode* store_node = analysis.AnalyzeInstruction(store_child); + SENode* store_node = analysis.AnalyzeInstruction(store_child); - opt::SENode* store_simplified = analysis.SimplifyExpression(store_node); + SENode* store_simplified = analysis.SimplifyExpression(store_node); - ir::Instruction* load_access_chain = + Instruction* load_access_chain = context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0)); - ir::Instruction* load_child = context->get_def_use_mgr()->GetDef( + Instruction* load_child = context->get_def_use_mgr()->GetDef( load_access_chain->GetSingleWordInOperand(1)); - opt::SENode* load_node = analysis.AnalyzeInstruction(load_child); + SENode* load_node = analysis.AnalyzeInstruction(load_child); - opt::SENode* load_simplified = analysis.SimplifyExpression(load_node); + SENode* load_simplified = analysis.SimplifyExpression(load_node); - opt::SENode* difference = + SENode* difference = analysis.CreateSubtraction(store_simplified, load_simplified); - opt::SENode* difference_simplified = - analysis.SimplifyExpression(difference); + SENode* difference_simplified = analysis.SimplifyExpression(difference); // Check that i+2*N - i*N, turns into just N when both sides have already // been simplified into a single recurrent expression. - EXPECT_EQ(difference_simplified->GetType(), opt::SENode::ValueUnknown); + EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown); // Check that the inverse, i*N - i+2*N turns into -N. - opt::SENode* difference_inverse = analysis.SimplifyExpression( + SENode* difference_inverse = analysis.SimplifyExpression( analysis.CreateSubtraction(load_simplified, store_simplified)); - EXPECT_EQ(difference_inverse->GetType(), opt::SENode::Negative); - EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), - opt::SENode::ValueUnknown); + EXPECT_EQ(difference_inverse->GetType(), SENode::Negative); + EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown); EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified); } { - ir::Instruction* store_access_chain = context->get_def_use_mgr()->GetDef( + Instruction* store_access_chain = context->get_def_use_mgr()->GetDef( stores[1]->GetSingleWordInOperand(0)); - ir::Instruction* store_child = context->get_def_use_mgr()->GetDef( + Instruction* store_child = context->get_def_use_mgr()->GetDef( store_access_chain->GetSingleWordInOperand(1)); - opt::SENode* store_node = analysis.AnalyzeInstruction(store_child); - opt::SENode* store_simplified = analysis.SimplifyExpression(store_node); + SENode* store_node = analysis.AnalyzeInstruction(store_child); + SENode* store_simplified = analysis.SimplifyExpression(store_node); - ir::Instruction* load_access_chain = + Instruction* load_access_chain = context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0)); - ir::Instruction* load_child = context->get_def_use_mgr()->GetDef( + Instruction* load_child = context->get_def_use_mgr()->GetDef( load_access_chain->GetSingleWordInOperand(1)); - opt::SENode* load_node = analysis.AnalyzeInstruction(load_child); + SENode* load_node = analysis.AnalyzeInstruction(load_child); - opt::SENode* load_simplified = analysis.SimplifyExpression(load_node); + SENode* load_simplified = analysis.SimplifyExpression(load_node); - opt::SENode* difference = + SENode* difference = analysis.CreateSubtraction(store_simplified, load_simplified); - opt::SENode* difference_simplified = - analysis.SimplifyExpression(difference); + SENode* difference_simplified = analysis.SimplifyExpression(difference); // Check that 2*i + 2*N + 1 - 2*i + N + 1, turns into just N when both // sides have already been simplified into a single recurrent expression. - EXPECT_EQ(difference_simplified->GetType(), opt::SENode::ValueUnknown); + EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown); // Check that the inverse, (2*i + N + 1) - (2*i + 2*N + 1) turns into -N. - opt::SENode* difference_inverse = analysis.SimplifyExpression( + SENode* difference_inverse = analysis.SimplifyExpression( analysis.CreateSubtraction(load_simplified, store_simplified)); - EXPECT_EQ(difference_inverse->GetType(), opt::SENode::Negative); - EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), - opt::SENode::ValueUnknown); + EXPECT_EQ(difference_inverse->GetType(), SENode::Negative); + EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown); EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified); } } @@ -1191,38 +1182,40 @@ TEST_F(ScalarAnalysisTest, InductionWithVariantStep) { OpReturn OpFunctionEnd )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - ir::Module* module = context->module(); + Module* module = context->module(); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << text << std::endl; - const ir::Function* f = spvtest::GetFunction(module, 2); - opt::ScalarEvolutionAnalysis analysis{context.get()}; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; - std::vector phis{}; + std::vector phis{}; - for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 21)) { + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) { if (inst.opcode() == SpvOp::SpvOpPhi) { phis.push_back(&inst); } } EXPECT_EQ(phis.size(), 2u); - opt::SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]); - opt::SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]); + SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]); + SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]); phi_node_1->DumpDot(std::cout, true); EXPECT_NE(phi_node_1, nullptr); EXPECT_NE(phi_node_2, nullptr); - EXPECT_EQ(phi_node_1->GetType(), opt::SENode::RecurrentAddExpr); - EXPECT_EQ(phi_node_2->GetType(), opt::SENode::CanNotCompute); + EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr); + EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute); - opt::SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1); - opt::SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2); + SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1); + SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2); - EXPECT_EQ(simplified_1->GetType(), opt::SENode::RecurrentAddExpr); - EXPECT_EQ(simplified_2->GetType(), opt::SENode::CanNotCompute); + EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr); + EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute); } } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/scalar_replacement_test.cpp b/3rdparty/spirv-tools/test/opt/scalar_replacement_test.cpp index 9c7d631eb..652978bb0 100644 --- a/3rdparty/spirv-tools/test/opt/scalar_replacement_test.cpp +++ b/3rdparty/spirv-tools/test/opt/scalar_replacement_test.cpp @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" +#include + #include "gmock/gmock.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ScalarReplacementTest = PassTest<::testing::Test>; // TODO(dneto): Add Effcee as required dependency, and make this unconditional. @@ -71,7 +73,7 @@ OpReturnValue %19 OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, StructInitialization) { @@ -125,7 +127,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, SpecConstantInitialization) { @@ -169,121 +171,122 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } // TODO(alanbaker): Re-enable when vector and matrix scalarization is supported. // TEST_F(ScalarReplacementTest, VectorInitialization) { // const std::string text = R"( -//; -//; CHECK: [[elem:%\w+]] = OpTypeInt 32 0 -//; CHECK: [[vector:%\w+]] = OpTypeVector [[elem]] 4 -//; CHECK: [[vector_ptr:%\w+]] = OpTypePointer Function [[vector]] -//; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]] -//; CHECK: [[zero:%\w+]] = OpConstant [[elem]] 0 -//; CHECK: [[undef:%\w+]] = OpUndef [[elem]] -//; CHECK: [[two:%\w+]] = OpConstant [[elem]] 2 -//; CHECK: [[null:%\w+]] = OpConstantNull [[elem]] -//; CHECK-NOT: OpVariable [[vector_ptr]] -//; CHECK: OpVariable [[elem_ptr]] Function [[zero]] -//; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] -//; CHECK-NEXT: OpVariable [[elem_ptr]] Function -//; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[two]] -//; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[null]] -//; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] -//; -// OpCapability Shader -// OpCapability Linkage -// OpMemoryModel Logical GLSL450 -// OpName %6 "vector_init" -//%1 = OpTypeVoid -//%2 = OpTypeInt 32 0 -//%3 = OpTypeVector %2 4 -//%4 = OpTypePointer Function %3 -//%20 = OpTypePointer Function %2 -//%6 = OpTypeFunction %1 -//%7 = OpConstant %2 0 -//%8 = OpUndef %2 -//%9 = OpConstant %2 2 -//%30 = OpConstant %2 1 -//%31 = OpConstant %2 3 -//%10 = OpConstantNull %2 -//%11 = OpConstantComposite %3 %10 %9 %8 %7 -//%12 = OpFunction %1 None %6 -//%13 = OpLabel -//%14 = OpVariable %4 Function %11 -//%15 = OpAccessChain %20 %14 %7 -// OpStore %15 %10 -//%16 = OpAccessChain %20 %14 %9 -// OpStore %16 %10 -//%17 = OpAccessChain %20 %14 %30 -// OpStore %17 %10 -//%18 = OpAccessChain %20 %14 %31 -// OpStore %18 %10 -// OpReturn -// OpFunctionEnd -// )"; +// ; +// ; CHECK: [[elem:%\w+]] = OpTypeInt 32 0 +// ; CHECK: [[vector:%\w+]] = OpTypeVector [[elem]] 4 +// ; CHECK: [[vector_ptr:%\w+]] = OpTypePointer Function [[vector]] +// ; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]] +// ; CHECK: [[zero:%\w+]] = OpConstant [[elem]] 0 +// ; CHECK: [[undef:%\w+]] = OpUndef [[elem]] +// ; CHECK: [[two:%\w+]] = OpConstant [[elem]] 2 +// ; CHECK: [[null:%\w+]] = OpConstantNull [[elem]] +// ; CHECK-NOT: OpVariable [[vector_ptr]] +// ; CHECK: OpVariable [[elem_ptr]] Function [[zero]] +// ; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] +// ; CHECK-NEXT: OpVariable [[elem_ptr]] Function +// ; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[two]] +// ; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[null]] +// ; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] +// ; +// OpCapability Shader +// OpCapability Linkage +// OpMemoryModel Logical GLSL450 +// OpName %6 "vector_init" +// %1 = OpTypeVoid +// %2 = OpTypeInt 32 0 +// %3 = OpTypeVector %2 4 +// %4 = OpTypePointer Function %3 +// %20 = OpTypePointer Function %2 +// %6 = OpTypeFunction %1 +// %7 = OpConstant %2 0 +// %8 = OpUndef %2 +// %9 = OpConstant %2 2 +// %30 = OpConstant %2 1 +// %31 = OpConstant %2 3 +// %10 = OpConstantNull %2 +// %11 = OpConstantComposite %3 %10 %9 %8 %7 +// %12 = OpFunction %1 None %6 +// %13 = OpLabel +// %14 = OpVariable %4 Function %11 +// %15 = OpAccessChain %20 %14 %7 +// OpStore %15 %10 +// %16 = OpAccessChain %20 %14 %9 +// OpStore %16 %10 +// %17 = OpAccessChain %20 %14 %30 +// OpStore %17 %10 +// %18 = OpAccessChain %20 %14 %31 +// OpStore %18 %10 +// OpReturn +// OpFunctionEnd +// )"; // -// SinglePassRunAndMatch(text, true); -//} +// SinglePassRunAndMatch(text, true); +// } // -// TEST_F(ScalarReplacementTest, MatrixInitialization) { -// const std::string text = R"( -//; -//; CHECK: [[float:%\w+]] = OpTypeFloat 32 -//; CHECK: [[vector:%\w+]] = OpTypeVector [[float]] 2 -//; CHECK: [[matrix:%\w+]] = OpTypeMatrix [[vector]] 2 -//; CHECK: [[matrix_ptr:%\w+]] = OpTypePointer Function [[matrix]] -//; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]] -//; CHECK: [[vec_ptr:%\w+]] = OpTypePointer Function [[vector]] -//; CHECK: [[zerof:%\w+]] = OpConstant [[float]] 0 -//; CHECK: [[onef:%\w+]] = OpConstant [[float]] 1 -//; CHECK: [[one_zero:%\w+]] = OpConstantComposite [[vector]] [[onef]] [[zerof]] -//; CHECK: [[zero_one:%\w+]] = OpConstantComposite [[vector]] [[zerof]] [[onef]] -//; CHECK: [[const_mat:%\w+]] = OpConstantComposite [[matrix]] [[one_zero]] -//[[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; CHECK-NOT: OpVariable -//[[vector]] Function [[one_zero]] ; CHECK: [[f1:%\w+]] = OpVariable -//[[float_ptr]] Function [[zerof]] ; CHECK-NEXT: [[f2:%\w+]] = OpVariable -//[[float_ptr]] Function [[onef]] ; CHECK-NEXT: [[vec_var:%\w+]] = OpVariable -//[[vec_ptr]] Function [[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; -// CHECK-NOT: OpVariable [[vector]] Function [[one_zero]] -//; -// OpCapability Shader -// OpCapability Linkage -// OpMemoryModel Logical GLSL450 -// OpName %7 "matrix_init" -//%1 = OpTypeVoid -//%2 = OpTypeFloat 32 -//%3 = OpTypeVector %2 2 -//%4 = OpTypeMatrix %3 2 -//%5 = OpTypePointer Function %4 -//%6 = OpTypePointer Function %2 -//%30 = OpTypePointer Function %3 -//%10 = OpTypeInt 32 0 -//%7 = OpTypeFunction %1 %10 -//%8 = OpConstant %2 0.0 -//%9 = OpConstant %2 1.0 -//%11 = OpConstant %10 0 -//%12 = OpConstant %10 1 -//%13 = OpConstantComposite %3 %9 %8 -//%14 = OpConstantComposite %3 %8 %9 -//%15 = OpConstantComposite %4 %13 %14 -//%16 = OpFunction %1 None %7 -//%31 = OpFunctionParameter %10 -//%17 = OpLabel -//%18 = OpVariable %5 Function %15 -//%19 = OpAccessChain %6 %18 %11 %12 -// OpStore %19 %8 -//%20 = OpAccessChain %6 %18 %11 %11 -// OpStore %20 %8 -//%21 = OpAccessChain %30 %18 %12 -// OpStore %21 %14 -// OpReturn -// OpFunctionEnd -// )"; +// TEST_F(ScalarReplacementTest, MatrixInitialization) { +// const std::string text = R"( +// ; +// ; CHECK: [[float:%\w+]] = OpTypeFloat 32 +// ; CHECK: [[vector:%\w+]] = OpTypeVector [[float]] 2 +// ; CHECK: [[matrix:%\w+]] = OpTypeMatrix [[vector]] 2 +// ; CHECK: [[matrix_ptr:%\w+]] = OpTypePointer Function [[matrix]] +// ; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]] +// ; CHECK: [[vec_ptr:%\w+]] = OpTypePointer Function [[vector]] +// ; CHECK: [[zerof:%\w+]] = OpConstant [[float]] 0 +// ; CHECK: [[onef:%\w+]] = OpConstant [[float]] 1 +// ; CHECK: [[one_zero:%\w+]] = OpConstantComposite [[vector]] [[onef]] +// [[zerof]] ; CHECK: [[zero_one:%\w+]] = OpConstantComposite [[vector]] +// [[zerof]] [[onef]] ; CHECK: [[const_mat:%\w+]] = OpConstantComposite +// [[matrix]] [[one_zero]] +// [[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; CHECK-NOT: OpVariable +// [[vector]] Function [[one_zero]] ; CHECK: [[f1:%\w+]] = OpVariable +// [[float_ptr]] Function [[zerof]] ; CHECK-NEXT: [[f2:%\w+]] = OpVariable +// [[float_ptr]] Function [[onef]] ; CHECK-NEXT: [[vec_var:%\w+]] = OpVariable +// [[vec_ptr]] Function [[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; +// CHECK-NOT: OpVariable [[vector]] Function [[one_zero]] +// ; +// OpCapability Shader +// OpCapability Linkage +// OpMemoryModel Logical GLSL450 +// OpName %7 "matrix_init" +// %1 = OpTypeVoid +// %2 = OpTypeFloat 32 +// %3 = OpTypeVector %2 2 +// %4 = OpTypeMatrix %3 2 +// %5 = OpTypePointer Function %4 +// %6 = OpTypePointer Function %2 +// %30 = OpTypePointer Function %3 +// %10 = OpTypeInt 32 0 +// %7 = OpTypeFunction %1 %10 +// %8 = OpConstant %2 0.0 +// %9 = OpConstant %2 1.0 +// %11 = OpConstant %10 0 +// %12 = OpConstant %10 1 +// %13 = OpConstantComposite %3 %9 %8 +// %14 = OpConstantComposite %3 %8 %9 +// %15 = OpConstantComposite %4 %13 %14 +// %16 = OpFunction %1 None %7 +// %31 = OpFunctionParameter %10 +// %17 = OpLabel +// %18 = OpVariable %5 Function %15 +// %19 = OpAccessChain %6 %18 %11 %12 +// OpStore %19 %8 +// %20 = OpAccessChain %6 %18 %11 %11 +// OpStore %20 %8 +// %21 = OpAccessChain %30 %18 %12 +// OpStore %21 %14 +// OpReturn +// OpFunctionEnd +// )"; // -// SinglePassRunAndMatch(text, true); -//} +// SinglePassRunAndMatch(text, true); +// } TEST_F(ScalarReplacementTest, ElideAccessChain) { const std::string text = R"( @@ -316,7 +319,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ElideMultipleAccessChains) { @@ -354,7 +357,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ReplaceAccessChain) { @@ -396,7 +399,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ArrayInitialization) { @@ -447,8 +450,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); - ; + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, NonUniformCompositeInitialization) { @@ -470,15 +472,12 @@ TEST_F(ScalarReplacementTest, NonUniformCompositeInitialization) { ; CHECK: [[const_array:%\w+]] = OpConstantComposite [[array]] ; CHECK: [[const_matrix:%\w+]] = OpConstantNull [[matrix]] ; CHECK: [[const_struct1:%\w+]] = OpConstantComposite [[struct1]] -; CHECK: [[vector_ptr:%\w+]] = OpTypePointer Function [[vector]] -; CHECK: [[long_ptr:%\w+]] = OpTypePointer Function [[long]] +; CHECK: OpConstantNull [[uint]] +; CHECK: OpConstantNull [[vector]] +; CHECK: OpConstantNull [[long]] +; CHECK: OpFunction ; CHECK-NOT: OpVariable [[struct2_ptr]] Function -; CHECK: OpVariable [[long_ptr]] Function -; CHECK: OpVariable [[long_ptr]] Function -; CHECK: OpVariable [[long_ptr]] Function -; CHECK: OpVariable [[vector_ptr]] Function -; CHECK: OpVariable [[uint_ptr]] Function -; CHECK: OpVariable [[uint_ptr]] Function +; CHECK: OpVariable [[uint_ptr]] Function ; CHECK-NEXT: OpVariable [[matrix_ptr]] Function [[const_matrix]] ; CHECK-NOT: OpVariable [[struct1_ptr]] Function [[const_struct1]] ; CHECK-NOT: OpVariable [[struct2_ptr]] Function @@ -532,8 +531,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); - ; + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ElideUncombinedAccessChains) { @@ -569,7 +567,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ElideSingleUncombinedAccessChains) { @@ -609,7 +607,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ReplaceWholeLoad) { @@ -635,6 +633,7 @@ OpName %func "replace_whole_load" %uint_ptr = OpTypePointer Function %uint %struct1_ptr = OpTypePointer Function %struct1 %uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 %func = OpTypeFunction %void %1 = OpFunction %void None %func %2 = OpLabel @@ -642,11 +641,13 @@ OpName %func "replace_whole_load" %load = OpLoad %struct1 %var %3 = OpAccessChain %uint_ptr %var %uint_0 OpStore %3 %uint_0 +%4 = OpAccessChain %uint_ptr %var %uint_1 +OpStore %4 %uint_0 OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ReplaceWholeLoadCopyMemoryAccess) { @@ -656,11 +657,10 @@ TEST_F(ScalarReplacementTest, ReplaceWholeLoadCopyMemoryAccess) { ; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] ; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] ; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 -; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[null:%\w+]] = OpConstantNull [[uint]] ; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function -; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] Nontemporal ; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]] Nontemporal -; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[l1]] +; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[null]] ; OpCapability Shader OpCapability Linkage @@ -683,7 +683,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ReplaceWholeStore) { @@ -694,12 +694,9 @@ TEST_F(ScalarReplacementTest, ReplaceWholeStore) { ; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] ; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 ; CHECK: [[const_struct:%\w+]] = OpConstantComposite [[struct1]] [[const]] [[const]] -; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function ; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function ; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 0 ; CHECK: OpStore [[var0]] [[ex0]] -; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 1 -; CHECK: OpStore [[var1]] [[ex1]] ; OpCapability Shader OpCapability Linkage @@ -723,7 +720,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, ReplaceWholeStoreCopyMemoryAccess) { @@ -734,12 +731,10 @@ TEST_F(ScalarReplacementTest, ReplaceWholeStoreCopyMemoryAccess) { ; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] ; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 ; CHECK: [[const_struct:%\w+]] = OpConstantComposite [[struct1]] [[const]] [[const]] -; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function ; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable ; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 0 ; CHECK: OpStore [[var0]] [[ex0]] Aligned 4 -; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 1 -; CHECK: OpStore [[var1]] [[ex1]] Aligned 4 ; OpCapability Shader OpCapability Linkage @@ -763,7 +758,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontTouchVolatileLoad) { @@ -795,7 +790,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontTouchVolatileStore) { @@ -827,7 +822,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontTouchSpecNonFunctionVariable) { @@ -859,7 +854,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontTouchSpecConstantAccessChain) { @@ -893,7 +888,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, NoPartialAccesses) { @@ -902,7 +897,6 @@ TEST_F(ScalarReplacementTest, NoPartialAccesses) { ; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 ; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] ; CHECK: OpLabel -; CHECK-NEXT: OpVariable [[uint_ptr]] ; CHECK-NOT: OpVariable ; OpCapability Shader @@ -924,7 +918,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontTouchPtrAccessChain) { @@ -958,7 +952,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(ScalarReplacementTest, DontTouchInBoundsPtrAccessChain) { @@ -992,7 +986,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(ScalarReplacementTest, DonTouchAliasedDecoration) { @@ -1025,7 +1019,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, CopyRestrictDecoration) { @@ -1067,7 +1061,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontClobberDecoratesOnSubtypes) { @@ -1105,7 +1099,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, DontCopyMemberDecorate) { @@ -1142,7 +1136,7 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); } TEST_F(ScalarReplacementTest, NoPartialAccesses2) { @@ -1266,8 +1260,268 @@ OpReturn OpFunctionEnd )"; - SinglePassRunAndMatch(text, true); + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[uint]] +; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]] +; CHECK: [[c0:%\w+]] = OpCompositeConstruct [[struct1]] [[l0]] [[null]] +; CHECK: [[e0:%\w+]] = OpCompositeExtract [[uint]] [[c0]] 0 +; CHECK: OpStore [[var1]] [[e0]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var2 = OpVariable %struct1_ptr Function +%var1 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore2) { + // TODO: We can improve this case by ensuring that |var2| is processed first. + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[uint]] +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; CHECK: [[l0a:%\w+]] = OpLoad [[uint]] [[var0a]] +; CHECK: [[l0b:%\w+]] = OpLoad [[uint]] [[var0b]] +; CHECK: [[c0:%\w+]] = OpCompositeConstruct [[struct1]] [[l0b]] [[l0a]] +; CHECK: [[e0:%\w+]] = OpCompositeExtract [[uint]] [[c0]] 0 +; CHECK: OpStore [[var1]] [[e0]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var1 = OpVariable %struct1_ptr Function +%var2 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant1) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[struct_member:%\w+]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[struct_member]] +; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; CHECK: OpStore [[var1]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct2 = OpTypeStruct %uint +%struct3 = OpTypeStruct %uint +%struct1 = OpTypeStruct %uint %struct2 +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var1 = OpVariable %struct1_ptr Function +%var2 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant2) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[struct_member:%\w+]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[struct_member]] +; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: OpStore [[var1]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct3 = OpTypeStruct %uint +%struct2 = OpTypeStruct %uint +%struct1 = OpTypeStruct %uint %struct2 +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var1 = OpVariable %struct1_ptr Function +%var2 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); } #endif // SPIRV_EFFCEE +// Test that a struct of size 4 is not replaced when there is a limit of 2. +TEST_F(ScalarReplacementTest, TestLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "simple_struct" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%5 = OpTypePointer Function %2 +%6 = OpTypeFunction %2 +%7 = OpConstantNull %3 +%8 = OpConstant %2 0 +%9 = OpConstant %2 1 +%10 = OpConstant %2 2 +%11 = OpConstant %2 3 +%12 = OpFunction %2 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %7 +%15 = OpInBoundsAccessChain %5 %14 %8 +%16 = OpLoad %2 %15 +%17 = OpAccessChain %5 %14 %10 +%18 = OpLoad %2 %17 +%19 = OpIAdd %2 %16 %18 +OpReturnValue %19 +OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false, 2); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// Test that a struct of size 4 is replaced when there is a limit of 0 (no +// limit). This is the same spir-v as a test above, so we do not check that it +// is correctly transformed. We leave that to the test above. +TEST_F(ScalarReplacementTest, TestUnimited) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "simple_struct" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%5 = OpTypePointer Function %2 +%6 = OpTypeFunction %2 +%7 = OpConstantNull %3 +%8 = OpConstant %2 0 +%9 = OpConstant %2 1 +%10 = OpConstant %2 2 +%11 = OpConstant %2 3 +%12 = OpFunction %2 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %7 +%15 = OpInBoundsAccessChain %5 %14 %8 +%16 = OpLoad %2 %15 +%17 = OpAccessChain %5 %14 %10 +%18 = OpLoad %2 %17 +%19 = OpIAdd %2 %16 %18 +OpReturnValue %19 +OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false, 0); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + } // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/set_spec_const_default_value_test.cpp b/3rdparty/spirv-tools/test/opt/set_spec_const_default_value_test.cpp index 973b5e494..161674fe0 100644 --- a/3rdparty/spirv-tools/test/opt/set_spec_const_default_value_test.cpp +++ b/3rdparty/spirv-tools/test/opt/set_spec_const_default_value_test.cpp @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" +#include -#include +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; using testing::Eq; - using SpecIdToValueStrMap = - opt::SetSpecConstantDefaultValuePass::SpecIdToValueStrMap; + SetSpecConstantDefaultValuePass::SpecIdToValueStrMap; using SpecIdToValueBitPatternMap = - opt::SetSpecConstantDefaultValuePass::SpecIdToValueBitPatternMap; + SetSpecConstantDefaultValuePass::SpecIdToValueBitPatternMap; struct DefaultValuesStringParsingTestCase { const char* default_values_str; @@ -37,9 +38,8 @@ using DefaultValuesStringParsingTest = TEST_P(DefaultValuesStringParsingTest, TestCase) { const auto& tc = GetParam(); - auto actual_map = - opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString( - tc.default_values_str); + auto actual_map = SetSpecConstantDefaultValuePass::ParseDefaultValuesString( + tc.default_values_str); if (tc.expect_success) { EXPECT_NE(nullptr, actual_map); if (actual_map) { @@ -144,7 +144,7 @@ using SetSpecConstantDefaultValueInStringFormParamTest = PassTest< TEST_P(SetSpecConstantDefaultValueInStringFormParamTest, TestCase) { const auto& tc = GetParam(); - SinglePassRunAndCheck( + SinglePassRunAndCheck( tc.code, tc.expected, /* skip_nop = */ false, tc.default_values); } @@ -606,7 +606,7 @@ using SetSpecConstantDefaultValueInBitPatternFormParamTest = TEST_P(SetSpecConstantDefaultValueInBitPatternFormParamTest, TestCase) { const auto& tc = GetParam(); - SinglePassRunAndCheck( + SinglePassRunAndCheck( tc.code, tc.expected, /* skip_nop = */ false, tc.default_values); } @@ -1072,4 +1072,6 @@ INSTANTIATE_TEST_CASE_P( }, })); -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/simplification_test.cpp b/3rdparty/spirv-tools/test/opt/simplification_test.cpp index 6e1ff23fd..b5ad26790 100644 --- a/3rdparty/spirv-tools/test/opt/simplification_test.cpp +++ b/3rdparty/spirv-tools/test/opt/simplification_test.cpp @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/simplification_pass.h" +#include -#include "assembly_builder.h" #include "gmock/gmock.h" -#include "pass_fixture.h" +#include "source/opt/simplification_pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using SimplificationTest = PassTest<::testing::Test>; #ifdef SPIRV_EFFCEE @@ -68,7 +69,7 @@ TEST_F(SimplificationTest, StraightLineTest) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(SimplificationTest, AcrossBasicBlocks) { @@ -132,7 +133,7 @@ TEST_F(SimplificationTest, AcrossBasicBlocks) { )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(SimplificationTest, ThroughLoops) { @@ -199,7 +200,11 @@ TEST_F(SimplificationTest, ThroughLoops) { OpFunctionEnd )"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } + #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/strength_reduction_test.cpp b/3rdparty/spirv-tools/test/opt/strength_reduction_test.cpp index 3ae785a1d..654c90df9 100644 --- a/3rdparty/spirv-tools/test/opt/strength_reduction_test.cpp +++ b/3rdparty/spirv-tools/test/opt/strength_reduction_test.cpp @@ -12,24 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" -#include "gmock/gmock.h" -#include "pass_fixture.h" -#include "pass_utils.h" - #include #include #include #include +#include #include +#include +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ::testing::HasSubstr; using ::testing::MatchesRegex; - using StrengthReductionBasicTest = PassTest<::testing::Test>; // Test to make sure we replace 5*8. @@ -54,10 +55,10 @@ TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) { // clang-format on }; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); const std::string& output = std::get<0>(result); EXPECT_THAT(output, Not(HasSubstr("OpIMul"))); EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3")); @@ -99,7 +100,7 @@ TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) { ; CHECK: OpFunctionEnd OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } #endif @@ -126,10 +127,10 @@ TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) { OpFunctionEnd )"; // clang-format on - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( text, /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); const std::string& output = std::get<0>(result); EXPECT_THAT(output, Not(HasSubstr("OpIMul"))); EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5")); @@ -157,10 +158,10 @@ TEST_F(StrengthReductionBasicTest, BasicDontReplace0) { // clang-format on }; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } // Test to make sure we do not replace a multiple of 5 and 7. @@ -186,10 +187,10 @@ TEST_F(StrengthReductionBasicTest, BasicNoChange) { // clang-format on }; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } // Test to make sure constants and types are reused and not duplicated. @@ -214,10 +215,10 @@ TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) { // clang-format on }; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); const std::string& output = std::get<0>(result); EXPECT_THAT(output, Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*"))); @@ -248,10 +249,10 @@ TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) { // clang-format on }; - auto result = SinglePassRunAndDisassemble( + auto result = SinglePassRunAndDisassemble( JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); - EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result)); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); const std::string& output = std::get<0>(result); EXPECT_THAT(output, Not(HasSubstr("OpIMul"))); EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7")); @@ -339,7 +340,7 @@ TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) { }; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(common_text, foo_before)), JoinAllInsts(Concat(common_text, foo_after)), /* skip_nop = */ true, /* do_validate = */ true); @@ -428,9 +429,12 @@ TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) { }; SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - SinglePassRunAndCheck( + SinglePassRunAndCheck( JoinAllInsts(Concat(common_text, foo_before)), JoinAllInsts(Concat(common_text, foo_after)), /* skip_nop = */ true, /* do_validate = */ true); } -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/strip_debug_info_test.cpp b/3rdparty/spirv-tools/test/opt/strip_debug_info_test.cpp index 8cae51e9b..f40ed382a 100644 --- a/3rdparty/spirv-tools/test/opt/strip_debug_info_test.cpp +++ b/3rdparty/spirv-tools/test/opt/strip_debug_info_test.cpp @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using StripLineDebugInfoTest = PassTest<::testing::Test>; TEST_F(StripLineDebugInfoTest, LineNoLine) { @@ -51,9 +53,9 @@ TEST_F(StripLineDebugInfoTest, LineNoLine) { "OpFunctionEnd", // clang-format on }; - SinglePassRunAndCheck(JoinAllInsts(text), - JoinNonDebugInsts(text), - /* skip_nop = */ false); + SinglePassRunAndCheck(JoinAllInsts(text), + JoinNonDebugInsts(text), + /* skip_nop = */ false); // Let's add more debug instruction before the "OpString" instruction. const std::vector more_text = { @@ -67,9 +69,9 @@ TEST_F(StripLineDebugInfoTest, LineNoLine) { "OpName %2 \"main\"", }; text.insert(text.begin() + 4, more_text.cbegin(), more_text.cend()); - SinglePassRunAndCheck(JoinAllInsts(text), - JoinNonDebugInsts(text), - /* skip_nop = */ false); + SinglePassRunAndCheck(JoinAllInsts(text), + JoinNonDebugInsts(text), + /* skip_nop = */ false); } using StripDebugInfoTest = PassTest<::testing::TestWithParam>; @@ -80,9 +82,9 @@ TEST_P(StripDebugInfoTest, Kind) { "OpMemoryModel Logical GLSL450", GetParam(), }; - SinglePassRunAndCheck(JoinAllInsts(text), - JoinNonDebugInsts(text), - /* skip_nop = */ false); + SinglePassRunAndCheck(JoinAllInsts(text), + JoinNonDebugInsts(text), + /* skip_nop = */ false); } // Test each possible non-line debug instruction. @@ -100,4 +102,6 @@ INSTANTIATE_TEST_CASE_P( }))); // clang-format on -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/strip_reflect_info_test.cpp b/3rdparty/spirv-tools/test/opt/strip_reflect_info_test.cpp index 3cf2cdce7..088cac7aa 100644 --- a/3rdparty/spirv-tools/test/opt/strip_reflect_info_test.cpp +++ b/3rdparty/spirv-tools/test/opt/strip_reflect_info_test.cpp @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "pass_fixture.h" -#include "pass_utils.h" +#include +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using StripLineReflectInfoTest = PassTest<::testing::Test>; TEST_F(StripLineReflectInfoTest, StripHlslSemantic) { @@ -40,7 +42,7 @@ OpMemoryModel Logical Simple %float = OpTypeFloat 32 )"; - SinglePassRunAndCheck(before, after, false); + SinglePassRunAndCheck(before, after, false); } TEST_F(StripLineReflectInfoTest, StripHlslCounterBuffer) { @@ -59,7 +61,9 @@ OpMemoryModel Logical Simple %float = OpTypeFloat 32 )"; - SinglePassRunAndCheck(before, after, false); + SinglePassRunAndCheck(before, after, false); } -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/type_manager_test.cpp b/3rdparty/spirv-tools/test/opt/type_manager_test.cpp index 8832b75f1..cf1fcb583 100644 --- a/3rdparty/spirv-tools/test/opt/type_manager_test.cpp +++ b/3rdparty/spirv-tools/test/opt/type_manager_test.cpp @@ -12,23 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/instruction.h" +#include "source/opt/type_manager.h" +#include "spirv-tools/libspirv.hpp" #ifdef SPIRV_EFFCEE #include "effcee/effcee.h" #endif -#include "opt/build_module.h" -#include "opt/instruction.h" -#include "opt/type_manager.h" -#include "spirv-tools/libspirv.hpp" - +namespace spvtools { +namespace opt { +namespace analysis { namespace { -using namespace spvtools; -using namespace spvtools::opt::analysis; - #ifdef SPIRV_EFFCEE bool Validate(const std::vector& bin) { @@ -43,7 +47,7 @@ bool Validate(const std::vector& bin) { return error == 0; } -void Match(const std::string& original, ir::IRContext* context, +void Match(const std::string& original, IRContext* context, bool do_validation = true) { std::vector bin; context->module()->ToBinary(&bin, true); @@ -131,10 +135,11 @@ std::vector> GenerateAllTypes() { auto* rav3s32 = types.back().get(); // Struct - types.emplace_back(new Struct(std::vector{s32})); - types.emplace_back(new Struct(std::vector{s32, f32})); + types.emplace_back(new Struct(std::vector{s32})); + types.emplace_back(new Struct(std::vector{s32, f32})); auto* sts32f32 = types.back().get(); - types.emplace_back(new Struct(std::vector{u64, a42f32, rav3s32})); + types.emplace_back( + new Struct(std::vector{u64, a42f32, rav3s32})); // Opaque types.emplace_back(new Opaque("")); @@ -173,7 +178,6 @@ std::vector> GenerateAllTypes() { TEST(TypeManager, TypeStrings) { const std::string text = R"( OpTypeForwardPointer !20 !2 ; id for %p is 20, Uniform is 2 - OpTypeForwardPointer !10000 !1 %void = OpTypeVoid %bool = OpTypeBool %u32 = OpTypeInt 32 0 @@ -235,19 +239,237 @@ TEST(TypeManager, TypeStrings) { {28, "named_barrier"}, }; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, context.get()); + TypeManager manager(nullptr, context.get()); EXPECT_EQ(type_id_strs.size(), manager.NumTypes()); - EXPECT_EQ(2u, manager.NumForwardPointers()); for (const auto& p : type_id_strs) { EXPECT_EQ(p.second, manager.GetType(p.first)->str()); EXPECT_EQ(p.first, manager.GetId(manager.GetType(p.first))); } - EXPECT_EQ("forward_pointer({uint32}*)", manager.GetForwardPointer(0)->str()); - EXPECT_EQ("forward_pointer(10000)", manager.GetForwardPointer(1)->str()); +} + +TEST(TypeManager, StructWithFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + %void = OpTypeVoid + %150 = OpTypeStruct %100 +%100 = OpTypePointer CrossWorkgroup %150 + %6 = OpTypeFunction %void %100 %100 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %100 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* s150 = manager.GetType(150); + + EXPECT_TRUE(p100->AsPointer()); + EXPECT_EQ(p100->AsPointer()->pointee_type(), s150); + + EXPECT_TRUE(s150->AsStruct()); + EXPECT_EQ(s150->AsStruct()->element_types()[0], p100); +} + +TEST(TypeManager, CircularFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + %void = OpTypeVoid + %int = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %150 = OpTypeStruct %200 %int + %250 = OpTypeStruct %100 %float +%100 = OpTypePointer CrossWorkgroup %150 +%200 = OpTypePointer CrossWorkgroup %250 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* s150 = manager.GetType(150); + Type* p200 = manager.GetType(200); + Type* s250 = manager.GetType(250); + + EXPECT_TRUE(p100->AsPointer()); + EXPECT_EQ(p100->AsPointer()->pointee_type(), s150); + + EXPECT_TRUE(p200->AsPointer()); + EXPECT_EQ(p200->AsPointer()->pointee_type(), s250); + + EXPECT_TRUE(s150->AsStruct()); + EXPECT_EQ(s150->AsStruct()->element_types()[0], p200); + + EXPECT_TRUE(s250->AsStruct()); + EXPECT_EQ(s250->AsStruct()->element_types()[0], p100); +} + +TEST(TypeManager, IsomorphicStructWithFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + %void = OpTypeVoid + %_struct_1 = OpTypeStruct %100 + %_struct_2 = OpTypeStruct %200 +%100 = OpTypePointer CrossWorkgroup %_struct_1 +%200 = OpTypePointer CrossWorkgroup %_struct_2 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + EXPECT_EQ(manager.GetType(100), manager.GetType(200)); +} + +TEST(TypeManager, IsomorphicCircularFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + OpTypeForwardPointer %300 CrossWorkgroup + OpTypeForwardPointer %400 CrossWorkgroup + %void = OpTypeVoid + %int = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %150 = OpTypeStruct %200 %int + %250 = OpTypeStruct %100 %float + %350 = OpTypeStruct %400 %int + %450 = OpTypeStruct %300 %float +%100 = OpTypePointer CrossWorkgroup %150 +%200 = OpTypePointer CrossWorkgroup %250 +%300 = OpTypePointer CrossWorkgroup %350 +%400 = OpTypePointer CrossWorkgroup %450 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* p300 = manager.GetType(300); + EXPECT_EQ(p100, p300); + Type* p200 = manager.GetType(200); + Type* p400 = manager.GetType(400); + EXPECT_EQ(p200, p400); + + Type* p150 = manager.GetType(150); + Type* p350 = manager.GetType(350); + EXPECT_EQ(p150, p350); + Type* p250 = manager.GetType(250); + Type* p450 = manager.GetType(450); + EXPECT_EQ(p250, p450); +} + +TEST(TypeManager, PartialIsomorphicFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + %void = OpTypeVoid + %int = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %150 = OpTypeStruct %200 %int + %250 = OpTypeStruct %200 %int +%100 = OpTypePointer CrossWorkgroup %150 +%200 = OpTypePointer CrossWorkgroup %250 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* p200 = manager.GetType(200); + EXPECT_EQ(p100->AsPointer()->pointee_type(), + p200->AsPointer()->pointee_type()); } TEST(TypeManager, DecorationOnStruct) { @@ -265,12 +487,11 @@ TEST(TypeManager, DecorationOnStruct) { %struct4 = OpTypeStruct %u32 %f32 ; the same %struct7 = OpTypeStruct %f32 ; no decoration )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, context.get()); + TypeManager manager(nullptr, context.get()); ASSERT_EQ(7u, manager.NumTypes()); - ASSERT_EQ(0u, manager.NumForwardPointers()); // Make sure we get ids correct. ASSERT_EQ("uint32", manager.GetType(5)->str()); ASSERT_EQ("float32", manager.GetType(6)->str()); @@ -315,12 +536,11 @@ TEST(TypeManager, DecorationOnMember) { %struct7 = OpTypeStruct %u32 %f32 ; extra decoration on the struct %struct10 = OpTypeStruct %u32 %f32 ; no member decoration )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, context.get()); + TypeManager manager(nullptr, context.get()); ASSERT_EQ(10u, manager.NumTypes()); - ASSERT_EQ(0u, manager.NumForwardPointers()); // Make sure we get ids correct. ASSERT_EQ("uint32", manager.GetType(8)->str()); ASSERT_EQ("float32", manager.GetType(9)->str()); @@ -353,12 +573,11 @@ TEST(TypeManager, DecorationEmpty) { %struct2 = OpTypeStruct %f32 %u32 %struct5 = OpTypeStruct %f32 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, context.get()); + TypeManager manager(nullptr, context.get()); ASSERT_EQ(5u, manager.NumTypes()); - ASSERT_EQ(0u, manager.NumForwardPointers()); // Make sure we get ids correct. ASSERT_EQ("uint32", manager.GetType(3)->str()); ASSERT_EQ("float32", manager.GetType(4)->str()); @@ -375,11 +594,10 @@ TEST(TypeManager, DecorationEmpty) { TEST(TypeManager, BeginEndForEmptyModule) { const std::string text = ""; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, context.get()); + TypeManager manager(nullptr, context.get()); ASSERT_EQ(0u, manager.NumTypes()); - ASSERT_EQ(0u, manager.NumForwardPointers()); EXPECT_EQ(manager.begin(), manager.end()); } @@ -392,11 +610,10 @@ TEST(TypeManager, BeginEnd) { %u32 = OpTypeInt 32 0 %f64 = OpTypeFloat 64 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, context.get()); + TypeManager manager(nullptr, context.get()); ASSERT_EQ(5u, manager.NumTypes()); - ASSERT_EQ(0u, manager.NumForwardPointers()); EXPECT_NE(manager.begin(), manager.end()); for (const auto& t : manager) { @@ -429,7 +646,7 @@ TEST(TypeManager, LookupType) { %vec2 = OpTypeVector %int 2 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(context, nullptr); TypeManager manager(nullptr, context.get()); @@ -457,7 +674,7 @@ OpMemoryModel Logical GLSL450 %2 = OpTypeInt 32 1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -480,7 +697,7 @@ OpMemoryModel Logical GLSL450 %2 = OpTypeStruct %1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -503,7 +720,7 @@ OpMemoryModel Logical GLSL450 %3 = OpTypeStruct %1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -528,7 +745,7 @@ OpMemoryModel Logical GLSL450 %3 = OpTypeStruct %1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -555,7 +772,7 @@ OpMemoryModel Logical GLSL450 %2 = OpTypeStruct %1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -589,7 +806,7 @@ OpMemoryModel Logical GLSL450 %2 = OpTypeInt 32 0 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -611,7 +828,7 @@ OpDecorate %3 Constant %3 = OpTypeStruct %1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -631,7 +848,7 @@ OpMemoryModel Logical GLSL450 %2 = OpTypeStruct %1 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -651,7 +868,7 @@ OpMemoryModel Logical GLSL450 %1 = OpTypeInt 32 0 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -674,7 +891,7 @@ OpCapability Shader OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -700,7 +917,7 @@ OpMemoryModel Logical GLSL450 %1 = OpTypeInt 32 0 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -733,7 +950,7 @@ OpCapability Linkage OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(context, nullptr); @@ -755,7 +972,7 @@ OpCapability Linkage OpMemoryModel Logical GLSL450 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); EXPECT_NE(context, nullptr); @@ -838,7 +1055,7 @@ OpMemoryModel Logical GLSL450 %100 = OpConstant %uint 100 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -864,7 +1081,7 @@ OpMemoryModel Logical GLSL450 %uint = OpTypeInt 32 0 )"; - std::unique_ptr context = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); EXPECT_NE(context, nullptr); @@ -878,6 +1095,59 @@ OpMemoryModel Logical GLSL450 Match(text, context.get()); } + +TEST(TypeManager, GetPointerToAmbiguousType1) { + const std::string text = R"( +; CHECK: [[struct1:%\w+]] = OpTypeStruct +; CHECK: [[struct2:%\w+]] = OpTypeStruct +; CHECK: OpTypePointer Function [[struct2]] +; CHECK: OpTypePointer Function [[struct1]] +OpCapability Shader +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 +%1 = OpTypeStruct %uint +%2 = OpTypeStruct %uint +%3 = OpTypePointer Function %2 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + context->get_type_mgr()->FindPointerToType(1, SpvStorageClassFunction); + Match(text, context.get()); +} + +TEST(TypeManager, GetPointerToAmbiguousType2) { + const std::string text = R"( +; CHECK: [[struct1:%\w+]] = OpTypeStruct +; CHECK: [[struct2:%\w+]] = OpTypeStruct +; CHECK: OpTypePointer Function [[struct1]] +; CHECK: OpTypePointer Function [[struct2]] +OpCapability Shader +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 +%1 = OpTypeStruct %uint +%2 = OpTypeStruct %uint +%3 = OpTypePointer Function %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + context->get_type_mgr()->FindPointerToType(2, SpvStorageClassFunction); + Match(text, context.get()); +} #endif // SPIRV_EFFCEE -} // anonymous namespace +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/types_test.cpp b/3rdparty/spirv-tools/test/opt/types_test.cpp index 1f081d437..c11187e83 100644 --- a/3rdparty/spirv-tools/test/opt/types_test.cpp +++ b/3rdparty/spirv-tools/test/opt/types_test.cpp @@ -13,28 +13,29 @@ // limitations under the License. #include +#include #include -#include - -#include "opt/make_unique.h" -#include "opt/types.h" +#include "gtest/gtest.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" +namespace spvtools { +namespace opt { +namespace analysis { namespace { -using namespace spvtools::opt::analysis; -using spvtools::MakeUnique; - // Fixture class providing some element types. class SameTypeTest : public ::testing::Test { protected: - virtual void SetUp() override { - void_t_.reset(new Void()); - u32_t_.reset(new Integer(32, false)); - f64_t_.reset(new Float(64)); - v3u32_t_.reset(new Vector(u32_t_.get(), 3)); - image_t_.reset(new Image(f64_t_.get(), SpvDim2D, 1, 1, 0, 0, - SpvImageFormatR16, SpvAccessQualifierReadWrite)); + void SetUp() override { + void_t_ = MakeUnique(); + u32_t_ = MakeUnique(32, false); + f64_t_ = MakeUnique(64); + v3u32_t_ = MakeUnique(u32_t_.get(), 3); + image_t_ = + MakeUnique(f64_t_.get(), SpvDim2D, 1, 1, 0, 0, SpvImageFormatR16, + SpvAccessQualifierReadWrite); } // Element types to be used for constructing other types for testing. @@ -73,8 +74,8 @@ TestMultipleInstancesOfTheSameType(Sampler); TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get()); TestMultipleInstancesOfTheSameType(Array, u32_t_.get(), 10); TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get()); -TestMultipleInstancesOfTheSameType(Struct, std::vector{u32_t_.get(), - f64_t_.get()}); +TestMultipleInstancesOfTheSameType(Struct, std::vector{ + u32_t_.get(), f64_t_.get()}); TestMultipleInstancesOfTheSameType(Opaque, "testing rocks"); TestMultipleInstancesOfTheSameType(Pointer, u32_t_.get(), SpvStorageClassInput); TestMultipleInstancesOfTheSameType(Function, u32_t_.get(), @@ -160,10 +161,11 @@ std::vector> GenerateAllTypes() { auto* rav3s32 = types.back().get(); // Struct - types.emplace_back(new Struct(std::vector{s32})); - types.emplace_back(new Struct(std::vector{s32, f32})); + types.emplace_back(new Struct(std::vector{s32})); + types.emplace_back(new Struct(std::vector{s32, f32})); auto* sts32f32 = types.back().get(); - types.emplace_back(new Struct(std::vector{u64, a42f32, rav3s32})); + types.emplace_back( + new Struct(std::vector{u64, a42f32, rav3s32})); // Opaque types.emplace_back(new Opaque("")); @@ -336,4 +338,7 @@ TEST(Types, RemoveDecorations) { } } -} // anonymous namespace +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/unify_const_test.cpp b/3rdparty/spirv-tools/test/opt/unify_const_test.cpp index 5c29fa79d..37728cc23 100644 --- a/3rdparty/spirv-tools/test/opt/unify_const_test.cpp +++ b/3rdparty/spirv-tools/test/opt/unify_const_test.cpp @@ -12,14 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include +#include +#include -#include "assembly_builder.h" -#include "pass_fixture.h" -#include "pass_utils.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; // Returns the types defining instructions commonly used in many tests. std::vector CommonTypes() { @@ -114,9 +119,9 @@ class UnifyConstantTest : public PassTest { // optimized code std::string optimized_before_strip; - auto status = opt::Pass::Status::SuccessWithoutChange; + auto status = Pass::Status::SuccessWithoutChange; std::tie(optimized_before_strip, status) = - this->template SinglePassRunAndDisassemble( + this->template SinglePassRunAndDisassemble( test_builder.GetCode(), /* skip_nop = */ true, /* do_validation = */ false); std::string optimized_without_opnames; @@ -125,9 +130,9 @@ class UnifyConstantTest : public PassTest { StripOpNameInstructionsToSet(optimized_before_strip); // Flag "status" should be returned correctly. - EXPECT_NE(opt::Pass::Status::Failure, status); + EXPECT_NE(Pass::Status::Failure, status); EXPECT_EQ(expected_without_opnames == original_without_opnames, - status == opt::Pass::Status::SuccessWithoutChange); + status == Pass::Status::SuccessWithoutChange); // Code except OpName instructions should be exactly the same. EXPECT_EQ(expected_without_opnames, optimized_without_opnames); // OpName instructions can be in different order, but the content must be @@ -980,4 +985,6 @@ INSTANTIATE_TEST_CASE_P(Case, UnifyFrontEndConstantParamTest, // clang-format on }))); -} // anonymous namespace +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/utils_test.cpp b/3rdparty/spirv-tools/test/opt/utils_test.cpp index e20e6856f..9bb82a367 100644 --- a/3rdparty/spirv-tools/test/opt/utils_test.cpp +++ b/3rdparty/spirv-tools/test/opt/utils_test.cpp @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include -#include "pass_utils.h" +#include "gtest/gtest.h" +#include "test/opt/pass_utils.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - TEST(JoinAllInsts, Cases) { EXPECT_EQ("", JoinAllInsts({})); EXPECT_EQ("a\n", JoinAllInsts({"a"})); @@ -43,7 +45,6 @@ TEST(JoinNonDebugInsts, Cases) { "the only remaining string"})); } -namespace { struct SubstringReplacementTestCase { const char* orig_str; const char* find_substr; @@ -51,7 +52,7 @@ struct SubstringReplacementTestCase { const char* expected_str; bool replace_should_succeed; }; -} // namespace + using FindAndReplaceTest = ::testing::TestWithParam; @@ -103,4 +104,7 @@ INSTANTIATE_TEST_CASE_P( {"abc", "a", "aab", "aabbc", true}, {"abc", "abcd", "efg", "abc", false}, }))); -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/value_table_test.cpp b/3rdparty/spirv-tools/test/opt/value_table_test.cpp index 4f42c816e..ef338ae7e 100644 --- a/3rdparty/spirv-tools/test/opt/value_table_test.cpp +++ b/3rdparty/spirv-tools/test/opt/value_table_test.cpp @@ -12,20 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opt/value_number_table.h" +#include -#include "assembly_builder.h" #include "gmock/gmock.h" -#include "opt/build_module.h" -#include "pass_fixture.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +namespace spvtools { +namespace opt { namespace { -using namespace spvtools; - using ::testing::HasSubstr; using ::testing::MatchesRegex; - using ValueTableTest = PassTest<::testing::Test>; TEST_F(ValueTableTest, SameInstructionSameValue) { @@ -49,8 +49,8 @@ TEST_F(ValueTableTest, SameInstructionSameValue) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst), vtable.GetValueNumber(inst)); } @@ -76,9 +76,9 @@ TEST_F(ValueTableTest, DifferentInstructionSameValue) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(11); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(11); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -106,9 +106,9 @@ TEST_F(ValueTableTest, SameValueDifferentBlock) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -134,9 +134,9 @@ TEST_F(ValueTableTest, DifferentValue) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(11); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(11); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -164,9 +164,9 @@ TEST_F(ValueTableTest, DifferentValueDifferentBlock) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -190,8 +190,8 @@ TEST_F(ValueTableTest, SameLoad) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst = context->get_def_use_mgr()->GetDef(9); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(9); EXPECT_EQ(vtable.GetValueNumber(inst), vtable.GetValueNumber(inst)); } @@ -218,9 +218,9 @@ TEST_F(ValueTableTest, DifferentFunctionLoad) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -245,9 +245,9 @@ TEST_F(ValueTableTest, DifferentUniformLoad) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -272,9 +272,9 @@ TEST_F(ValueTableTest, DifferentInputLoad) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -299,9 +299,9 @@ TEST_F(ValueTableTest, DifferentUniformConstantLoad) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -326,9 +326,9 @@ TEST_F(ValueTableTest, DifferentPushConstantLoad) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -358,8 +358,8 @@ TEST_F(ValueTableTest, SameCall) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst), vtable.GetValueNumber(inst)); } @@ -391,9 +391,9 @@ TEST_F(ValueTableTest, DifferentCall) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -422,9 +422,9 @@ TEST_F(ValueTableTest, DifferentTypes) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(11); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(11); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -449,9 +449,9 @@ TEST_F(ValueTableTest, CopyObject) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); } @@ -487,10 +487,10 @@ TEST_F(ValueTableTest, PhiTest1) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(13); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(15); - ir::Instruction* phi = context->get_def_use_mgr()->GetDef(16); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(13); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(15); + Instruction* phi = context->get_def_use_mgr()->GetDef(16); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi)); } @@ -528,10 +528,10 @@ TEST_F(ValueTableTest, PhiTest2) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(14); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(16); - ir::Instruction* phi = context->get_def_use_mgr()->GetDef(17); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(14); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(16); + Instruction* phi = context->get_def_use_mgr()->GetDef(17); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi)); EXPECT_NE(vtable.GetValueNumber(inst2), vtable.GetValueNumber(phi)); @@ -573,16 +573,19 @@ TEST_F(ValueTableTest, PhiLoopTest) { OpFunctionEnd )"; auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - opt::ValueNumberTable vtable(context.get()); - ir::Instruction* inst1 = context->get_def_use_mgr()->GetDef(12); - ir::Instruction* inst2 = context->get_def_use_mgr()->GetDef(16); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(12); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(16); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); - ir::Instruction* phi1 = context->get_def_use_mgr()->GetDef(15); + Instruction* phi1 = context->get_def_use_mgr()->GetDef(15); EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi1)); - ir::Instruction* phi2 = context->get_def_use_mgr()->GetDef(18); + Instruction* phi2 = context->get_def_use_mgr()->GetDef(18); EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi2)); EXPECT_NE(vtable.GetValueNumber(phi1), vtable.GetValueNumber(phi2)); } -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/vector_dce_test.cpp b/3rdparty/spirv-tools/test/opt/vector_dce_test.cpp new file mode 100644 index 000000000..d1af0de19 --- /dev/null +++ b/3rdparty/spirv-tools/test/opt/vector_dce_test.cpp @@ -0,0 +1,1158 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using VectorDCETest = PassTest<::testing::Test>; + +TEST_F(VectorDCETest, InsertAfterInsertElim) { + // With two insertions to the same offset, the first is dead. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in float In0; + // layout (location=1) in float In1; + // layout (location=2) in vec2 In2; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v = In2; + // v.x = In0 + In1; // dead + // v.x = 0.0; + // OutColor = v.xyxy; + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In2 %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In2 "In2" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_n" +OpName %_ "" +OpDecorate %In2 Location 2 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%In2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Input_float = OpTypePointer Input %float +%In0 = OpVariable %_ptr_Input_float Input +%In1 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%int = OpTypeInt 32 1 +%_Globals_ = OpTypeStruct %uint %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In2 %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In2 "In2" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_n" +OpName %_ "" +OpDecorate %In2 Location 2 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%In2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Input_float = OpTypePointer Input %float +%In0 = OpVariable %_ptr_Input_float Input +%In1 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%int = OpTypeInt 32 1 +%_Globals_ = OpTypeStruct %uint %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%25 = OpLabel +%26 = OpLoad %v2float %In2 +%27 = OpLoad %float %In0 +%28 = OpLoad %float %In1 +%29 = OpFAdd %float %27 %28 +%35 = OpCompositeInsert %v2float %29 %26 0 +%37 = OpCompositeInsert %v2float %float_0 %35 0 +%33 = OpVectorShuffle %v4float %37 %37 0 1 0 1 +OpStore %OutColor %33 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%23 = OpLabel +%24 = OpLoad %v2float %In2 +%25 = OpLoad %float %In0 +%26 = OpLoad %float %In1 +%27 = OpFAdd %float %25 %26 +%28 = OpCompositeInsert %v2float %27 %24 0 +%29 = OpCompositeInsert %v2float %float_0 %24 0 +%30 = OpVectorShuffle %v4float %29 %29 0 1 0 1 +OpStore %OutColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(VectorDCETest, DeadInsertInChainWithPhi) { + // Dead insert eliminated with phi in insertion chain. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // bool g_b; + // }; + // + // void main() + // { + // vec4 v = In0; + // v.z = In1 + In2; + // if (g_b) v.w = 1.0; + // OutColor = vec4(v.x,v.y,0.0,v.w); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%31 = OpLabel +%32 = OpLoad %v4float %In0 +%33 = OpLoad %float %In1 +%34 = OpLoad %float %In2 +%35 = OpFAdd %float %33 %34 +%51 = OpCompositeInsert %v4float %35 %32 2 +%37 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%38 = OpLoad %uint %37 +%39 = OpINotEqual %bool %38 %uint_0 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %40 +%41 = OpLabel +%53 = OpCompositeInsert %v4float %float_1 %51 3 +OpBranch %40 +%40 = OpLabel +%60 = OpPhi %v4float %51 %31 %53 %41 +%55 = OpCompositeExtract %float %60 0 +%57 = OpCompositeExtract %float %60 1 +%59 = OpCompositeExtract %float %60 3 +%49 = OpCompositeConstruct %v4float %55 %57 %float_0 %59 +OpStore %OutColor %49 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%27 = OpLabel +%28 = OpLoad %v4float %In0 +%29 = OpLoad %float %In1 +%30 = OpLoad %float %In2 +%31 = OpFAdd %float %29 %30 +%32 = OpCompositeInsert %v4float %31 %28 2 +%33 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%34 = OpLoad %uint %33 +%35 = OpINotEqual %bool %34 %uint_0 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %36 +%37 = OpLabel +%38 = OpCompositeInsert %v4float %float_1 %28 3 +OpBranch %36 +%36 = OpLabel +%39 = OpPhi %v4float %28 %27 %38 %37 +%40 = OpCompositeExtract %float %39 0 +%41 = OpCompositeExtract %float %39 1 +%42 = OpCompositeExtract %float %39 3 +%43 = OpCompositeConstruct %v4float %40 %41 %float_0 %42 +OpStore %OutColor %43 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(VectorDCETest, DeadInsertWithScalars) { + // Dead insert which requires two passes to eliminate + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // bool g_b; + // bool g_b2; + // }; + // + // void main() + // { + // vec4 v1, v2; + // v1 = In0; + // v1.y = In1 + In2; // dead, second pass + // if (g_b) v1.x = 1.0; + // v2.x = v1.x; + // v2.y = v1.y; // dead, first pass + // if (g_b2) v2.x = 0.0; + // OutColor = vec4(v2.x,v2.x,0.0,1.0); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_b2" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%27 = OpUndef %v4float +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_b2" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%27 = OpUndef %v4float +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%30 = OpLoad %float %In1 +%31 = OpLoad %float %In2 +%32 = OpFAdd %float %30 %31 +%33 = OpCompositeInsert %v4float %32 %29 1 +%34 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%35 = OpLoad %uint %34 +%36 = OpINotEqual %bool %35 %uint_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +%39 = OpCompositeInsert %v4float %float_1 %33 0 +OpBranch %37 +%37 = OpLabel +%40 = OpPhi %v4float %33 %28 %39 %38 +%41 = OpCompositeExtract %float %40 0 +%42 = OpCompositeInsert %v4float %41 %27 0 +%43 = OpCompositeExtract %float %40 1 +%44 = OpCompositeInsert %v4float %43 %42 1 +%45 = OpAccessChain %_ptr_Uniform_uint %_ %int_1 +%46 = OpLoad %uint %45 +%47 = OpINotEqual %bool %46 %uint_0 +OpSelectionMerge %48 None +OpBranchConditional %47 %49 %48 +%49 = OpLabel +%50 = OpCompositeInsert %v4float %float_0 %44 0 +OpBranch %48 +%48 = OpLabel +%51 = OpPhi %v4float %44 %37 %50 %49 +%52 = OpCompositeExtract %float %51 0 +%53 = OpCompositeExtract %float %51 0 +%54 = OpCompositeConstruct %v4float %52 %53 %float_0 %float_1 +OpStore %OutColor %54 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%34 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%35 = OpLoad %uint %34 +%36 = OpINotEqual %bool %35 %uint_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +%39 = OpCompositeInsert %v4float %float_1 %29 0 +OpBranch %37 +%37 = OpLabel +%40 = OpPhi %v4float %29 %28 %39 %38 +%41 = OpCompositeExtract %float %40 0 +%42 = OpCompositeInsert %v4float %41 %27 0 +%45 = OpAccessChain %_ptr_Uniform_uint %_ %int_1 +%46 = OpLoad %uint %45 +%47 = OpINotEqual %bool %46 %uint_0 +OpSelectionMerge %48 None +OpBranchConditional %47 %49 %48 +%49 = OpLabel +%50 = OpCompositeInsert %v4float %float_0 %42 0 +OpBranch %48 +%48 = OpLabel +%51 = OpPhi %v4float %42 %37 %50 %49 +%52 = OpCompositeExtract %float %51 0 +%53 = OpCompositeExtract %float %51 0 +%54 = OpCompositeConstruct %v4float %52 %53 %float_0 %float_1 +OpStore %OutColor %54 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(VectorDCETest, InsertObjectLive) { + // Make sure that the object being inserted in an OpCompositeInsert + // is not removed when it is uses later on. + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%30 = OpLoad %float %In1 +%33 = OpCompositeInsert %v4float %30 %29 1 +OpStore %OutColor %33 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, before, true, true); +} + +#ifdef SPIRV_EFFCEE +TEST_F(VectorDCETest, DeadInsertInCycle) { + // Dead insert in chain with cycle. Demonstrates analysis can handle + // cycles in chains going through scalars intermediate values. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // int g_n ; + // }; + // + // void main() + // { + // vec2 v = vec2(0.0, 1.0); + // for (int i = 0; i < g_n; i++) { + // v.x = v.x + 1; + // v.y = v.y * 0.9; // dead + // } + // OutColor = vec4(v.x); + // } + + const std::string assembly = + R"( +; CHECK: [[init_val:%\w+]] = OpConstantComposite %v2float %float_0 %float_1 +; CHECK: [[undef:%\w+]] = OpUndef %v2float +; CHECK: OpFunction +; CHECK: [[entry_lab:%\w+]] = OpLabel +; CHECK: [[loop_header:%\w+]] = OpLabel +; CHECK: OpPhi %v2float [[init_val]] [[entry_lab]] [[x_insert:%\w+]] {{%\w+}} +; CHECK: [[x_insert:%\w+]] = OpCompositeInsert %v2float %43 [[undef]] 0 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %In0 %In1 %In2 +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_n" +OpName %_ "" +OpName %OutColor "OutColor" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%16 = OpConstantComposite %v2float %float_0 %float_1 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%_Globals_ = OpTypeStruct %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%_ptr_Uniform_int = OpTypePointer Uniform %int +%bool = OpTypeBool +%float_0_75 = OpConstant %float 0.75 +%int_1 = OpConstant %int 1 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%main = OpFunction %void None %10 +%29 = OpLabel +OpBranch %30 +%30 = OpLabel +%31 = OpPhi %v2float %16 %29 %32 %33 +%34 = OpPhi %int %int_0 %29 %35 %33 +OpLoopMerge %36 %33 None +OpBranch %37 +%37 = OpLabel +%38 = OpAccessChain %_ptr_Uniform_int %_ %int_0 +%39 = OpLoad %int %38 +%40 = OpSLessThan %bool %34 %39 +OpBranchConditional %40 %41 %36 +%41 = OpLabel +%42 = OpCompositeExtract %float %31 0 +%43 = OpFAdd %float %42 %float_1 +%44 = OpCompositeInsert %v2float %43 %31 0 +%45 = OpCompositeExtract %float %44 1 +%46 = OpFMul %float %45 %float_0_75 +%32 = OpCompositeInsert %v2float %46 %44 1 +OpBranch %33 +%33 = OpLabel +%35 = OpIAdd %int %34 %int_1 +OpBranch %30 +%36 = OpLabel +%47 = OpCompositeExtract %float %31 0 +%48 = OpCompositeConstruct %v4float %47 %47 %47 %47 +OpStore %OutColor %48 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(VectorDCETest, DeadLoadFeedingCompositeConstruct) { + // Detach the loads feeding the CompositeConstruct for the unused elements. + // TODO: Implement the rewrite for CompositeConstruct. + + const std::string assembly = + R"( +; CHECK: [[undef:%\w+]] = OpUndef %float +; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Input_float %In0 %uint_2 +; CHECK: [[load:%\w+]] = OpLoad %float [[ac]] +; CHECK: OpCompositeConstruct %v3float [[load]] [[undef]] [[undef]] +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" +OpSourceExtension "GL_GOOGLE_include_directive" +OpName %main "main" +OpName %In0 "In0" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%v3float = OpTypeVector %float 3 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_20 = OpConstant %int 20 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%23 = OpUndef %v3float +%main = OpFunction %void None %6 +%24 = OpLabel +%25 = OpAccessChain %_ptr_Input_float %In0 %uint_0 +%26 = OpLoad %float %25 +%27 = OpAccessChain %_ptr_Input_float %In0 %uint_1 +%28 = OpLoad %float %27 +%29 = OpAccessChain %_ptr_Input_float %In0 %uint_2 +%30 = OpLoad %float %29 +%31 = OpCompositeConstruct %v3float %30 %28 %26 +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %v3float %31 %24 %34 %35 +%36 = OpPhi %int %int_0 %24 %37 %35 +OpLoopMerge %38 %35 None +OpBranch %39 +%39 = OpLabel +%40 = OpSLessThan %bool %36 %int_20 +OpBranchConditional %40 %41 %38 +%41 = OpLabel +%42 = OpCompositeExtract %float %33 0 +%43 = OpFAdd %float %42 %float_1 +%34 = OpCompositeInsert %v3float %43 %33 0 +OpBranch %35 +%35 = OpLabel +%37 = OpIAdd %int %36 %int_1 +OpBranch %32 +%38 = OpLabel +%44 = OpCompositeExtract %float %33 0 +%45 = OpCompositeConstruct %v4float %44 %44 %44 %44 +OpStore %OutColor %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(VectorDCETest, DeadLoadFeedingVectorShuffle) { + // Detach the loads feeding the CompositeConstruct for the unused elements. + // TODO: Implement the rewrite for CompositeConstruct. + + const std::string assembly = + R"( +; MemPass Type2Undef does not reuse and already existing undef. +; CHECK: {{%\w+}} = OpUndef %v3float +; CHECK: [[undef:%\w+]] = OpUndef %v3float +; CHECK: OpFunction +; CHECK: OpVectorShuffle %v3float {{%\w+}} [[undef]] 0 4 5 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %In0 %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %In0 "In0" + OpName %OutColor "OutColor" + OpDecorate %In0 Location 0 + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %In0 = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_20 = OpConstant %int 20 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 + %vec_const = OpConstantComposite %v3float %float_1 %float_1 %float_1 + %int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %23 = OpUndef %v3float + %main = OpFunction %void None %6 + %24 = OpLabel + %25 = OpAccessChain %_ptr_Input_float %In0 %uint_0 + %26 = OpLoad %float %25 + %27 = OpAccessChain %_ptr_Input_float %In0 %uint_1 + %28 = OpLoad %float %27 + %29 = OpAccessChain %_ptr_Input_float %In0 %uint_2 + %30 = OpLoad %float %29 + %31 = OpCompositeConstruct %v3float %30 %28 %26 + %sh = OpVectorShuffle %v3float %vec_const %31 0 4 5 + OpBranch %32 + %32 = OpLabel + %33 = OpPhi %v3float %sh %24 %34 %35 + %36 = OpPhi %int %int_0 %24 %37 %35 + OpLoopMerge %38 %35 None + OpBranch %39 + %39 = OpLabel + %40 = OpSLessThan %bool %36 %int_20 + OpBranchConditional %40 %41 %38 + %41 = OpLabel + %42 = OpCompositeExtract %float %33 0 + %43 = OpFAdd %float %42 %float_1 + %34 = OpCompositeInsert %v3float %43 %33 0 + OpBranch %35 + %35 = OpLabel + %37 = OpIAdd %int %36 %int_1 + OpBranch %32 + %38 = OpLabel + %44 = OpCompositeExtract %float %33 0 + %45 = OpCompositeConstruct %v4float %44 %44 %44 %44 + OpStore %OutColor %45 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(VectorDCETest, DeadInstThroughShuffle) { + // Dead insert in chain with cycle. Demonstrates analysis can handle + // cycles in chains. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v; + // v.x = 0.0; + // v.y = 0.1; // dead + // for (int i = 0; i < 20; i++) { + // v.x = v.x + 1; + // v = v * 0.9; + // } + // OutColor = vec4(v.x); + // } + + const std::string assembly = + R"( +; CHECK: OpFunction +; CHECK-NOT: OpCompositeInsert %v2float {{%\w+}} 1 +; CHECK: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %OutColor "OutColor" + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %float_0 = OpConstant %float 0 +%float_0_100000001 = OpConstant %float 0.100000001 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_20 = OpConstant %int 20 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 +%float_0_899999976 = OpConstant %float 0.899999976 + %int_1 = OpConstant %int 1 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %58 = OpUndef %v2float + %main = OpFunction %void None %3 + %5 = OpLabel + %49 = OpCompositeInsert %v2float %float_0 %58 0 + %51 = OpCompositeInsert %v2float %float_0_100000001 %49 1 + OpBranch %22 + %22 = OpLabel + %60 = OpPhi %v2float %51 %5 %38 %25 + %59 = OpPhi %int %int_0 %5 %41 %25 + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %30 = OpSLessThan %bool %59 %int_20 + OpBranchConditional %30 %23 %24 + %23 = OpLabel + %53 = OpCompositeExtract %float %60 0 + %34 = OpFAdd %float %53 %float_1 + %55 = OpCompositeInsert %v2float %34 %60 0 + %38 = OpVectorTimesScalar %v2float %55 %float_0_899999976 + OpBranch %25 + %25 = OpLabel + %41 = OpIAdd %int %59 %int_1 + OpBranch %22 + %24 = OpLabel + %57 = OpCompositeExtract %float %60 0 + %47 = OpCompositeConstruct %v4float %57 %57 %57 %57 + OpStore %OutColor %47 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(VectorDCETest, DeadInsertThroughOtherInst) { + // Dead insert in chain with cycle. Demonstrates analysis can handle + // cycles in chains. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v; + // v.x = 0.0; + // v.y = 0.1; // dead + // for (int i = 0; i < 20; i++) { + // v.x = v.x + 1; + // v = v * 0.9; + // } + // OutColor = vec4(v.x); + // } + + const std::string assembly = + R"( +; CHECK: OpFunction +; CHECK-NOT: OpCompositeInsert %v2float {{%\w+}} 1 +; CHECK: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %OutColor "OutColor" + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %float_0 = OpConstant %float 0 +%float_0_100000001 = OpConstant %float 0.100000001 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_20 = OpConstant %int 20 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 +%float_0_899999976 = OpConstant %float 0.899999976 + %int_1 = OpConstant %int 1 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %58 = OpUndef %v2float + %main = OpFunction %void None %3 + %5 = OpLabel + %49 = OpCompositeInsert %v2float %float_0 %58 0 + %51 = OpCompositeInsert %v2float %float_0_100000001 %49 1 + OpBranch %22 + %22 = OpLabel + %60 = OpPhi %v2float %51 %5 %38 %25 + %59 = OpPhi %int %int_0 %5 %41 %25 + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %30 = OpSLessThan %bool %59 %int_20 + OpBranchConditional %30 %23 %24 + %23 = OpLabel + %53 = OpCompositeExtract %float %60 0 + %34 = OpFAdd %float %53 %float_1 + %55 = OpCompositeInsert %v2float %34 %60 0 + %38 = OpVectorTimesScalar %v2float %55 %float_0_899999976 + OpBranch %25 + %25 = OpLabel + %41 = OpIAdd %int %59 %int_1 + OpBranch %22 + %24 = OpLabel + %57 = OpCompositeExtract %float %60 0 + %47 = OpCompositeConstruct %v4float %57 %57 %57 %57 + OpStore %OutColor %47 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} +#endif + +TEST_F(VectorDCETest, VectorIntoCompositeConstruct) { + const std::string text = R"(OpCapability Linkage +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "EntryPoint_Main" %2 %3 +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 Location 0 +OpDecorate %_struct_4 Block +OpDecorate %3 Location 0 +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%mat4v4float = OpTypeMatrix %v4float 4 +%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float +%v3float = OpTypeVector %float 3 +%_ptr_Function_v3float = OpTypePointer Function %v3float +%_struct_14 = OpTypeStruct %v2float %mat4v4float %v3float %v2float %v4float +%_ptr_Function__struct_14 = OpTypePointer Function %_struct_14 +%void = OpTypeVoid +%int = OpTypeInt 32 1 +%int_2 = OpConstant %int 2 +%int_1 = OpConstant %int 1 +%int_4 = OpConstant %int 4 +%int_0 = OpConstant %int 0 +%int_3 = OpConstant %int 3 +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v2float = OpTypePointer Output %v2float +%_struct_4 = OpTypeStruct %v2float +%_ptr_Output__struct_4 = OpTypePointer Output %_struct_4 +%3 = OpVariable %_ptr_Output__struct_4 Output +%28 = OpTypeFunction %void +%29 = OpConstantComposite %v2float %float_0 %float_0 +%30 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%31 = OpConstantComposite %mat4v4float %30 %30 %30 %30 +%32 = OpConstantComposite %v3float %float_0 %float_0 %float_0 +%1 = OpFunction %void None %28 +%33 = OpLabel +%34 = OpVariable %_ptr_Function_v4float Function +%35 = OpVariable %_ptr_Function__struct_14 Function +%36 = OpAccessChain %_ptr_Function_v2float %35 %int_0 +OpStore %36 %29 +%37 = OpAccessChain %_ptr_Function_mat4v4float %35 %int_1 +OpStore %37 %31 +%38 = OpAccessChain %_ptr_Function_v3float %35 %int_2 +OpStore %38 %32 +%39 = OpAccessChain %_ptr_Function_v2float %35 %int_3 +OpStore %39 %29 +%40 = OpAccessChain %_ptr_Function_v4float %35 %int_4 +OpStore %40 %30 +%41 = OpLoad %v2float %2 +OpStore %36 %41 +%42 = OpLoad %v3float %38 +%43 = OpCompositeConstruct %v4float %42 %float_1 +%44 = OpLoad %mat4v4float %37 +%45 = OpVectorTimesMatrix %v4float %43 %44 +OpStore %34 %45 +OpCopyMemory %40 %34 +OpCopyMemory %36 %39 +%46 = OpAccessChain %_ptr_Output_v2float %3 %int_0 +%47 = OpLoad %v2float %36 +OpStore %46 %47 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/opt/workaround1209_test.cpp b/3rdparty/spirv-tools/test/opt/workaround1209_test.cpp index 5708b7e6a..853a01cb2 100644 --- a/3rdparty/spirv-tools/test/opt/workaround1209_test.cpp +++ b/3rdparty/spirv-tools/test/opt/workaround1209_test.cpp @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "assembly_builder.h" -#include "gmock/gmock.h" -#include "pass_fixture.h" -#include "pass_utils.h" - #include #include #include #include +#include #include -namespace { +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" -using namespace spvtools; +namespace spvtools { +namespace opt { +namespace { using Workaround1209Test = PassTest<::testing::Test>; @@ -119,7 +120,7 @@ TEST_F(Workaround1209Test, RemoveOpUnreachableInLoop) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(Workaround1209Test, RemoveOpUnreachableInNestedLoop) { @@ -219,7 +220,7 @@ TEST_F(Workaround1209Test, RemoveOpUnreachableInNestedLoop) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(Workaround1209Test, RemoveOpUnreachableInAdjacentLoops) { @@ -333,7 +334,7 @@ TEST_F(Workaround1209Test, RemoveOpUnreachableInAdjacentLoops) { OpReturn OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } TEST_F(Workaround1209Test, LeaveUnreachableNotInLoop) { @@ -415,7 +416,10 @@ TEST_F(Workaround1209Test, LeaveUnreachableNotInLoop) { OpUnreachable OpFunctionEnd)"; - SinglePassRunAndMatch(text, false); + SinglePassRunAndMatch(text, false); } #endif -} // anonymous namespace + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/parse_number_test.cpp b/3rdparty/spirv-tools/test/parse_number_test.cpp index 9189fbe4e..c99205cf5 100644 --- a/3rdparty/spirv-tools/test/parse_number_test.cpp +++ b/3rdparty/spirv-tools/test/parse_number_test.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -19,13 +20,10 @@ #include "source/util/parse_number.h" #include "spirv-tools/libspirv.h" +namespace spvtools { +namespace utils { namespace { -using spvutils::EncodeNumberStatus; -using spvutils::NumberType; -using spvutils::ParseAndEncodeFloatingPointNumber; -using spvutils::ParseAndEncodeIntegerNumber; -using spvutils::ParseAndEncodeNumber; -using spvutils::ParseNumber; + using testing::Eq; using testing::IsNull; using testing::NotNull; @@ -191,7 +189,7 @@ TEST(ParseFloat, Overflow) { // range values. When it does overflow, the value is set to the // nearest finite value, matching C++11 behavior for operator>> // on floating point. - spvutils::HexFloat> f(0.0f); + HexFloat> f(0.0f); EXPECT_TRUE(ParseNumber("1e38", &f)); EXPECT_EQ(1e38f, f.value().getAsFloat()); @@ -236,7 +234,7 @@ TEST(ParseDouble, Overflow) { // range values. When it does overflow, the value is set to the // nearest finite value, matching C++11 behavior for operator>> // on floating point. - spvutils::HexFloat> f(0.0); + HexFloat> f(0.0); EXPECT_TRUE(ParseNumber("1e38", &f)); EXPECT_EQ(1e38, f.value().getAsFloat()); @@ -256,7 +254,7 @@ TEST(ParseFloat16, Overflow) { // range values. When it does overflow, the value is set to the // nearest finite value, matching C++11 behavior for operator>> // on floating point. - spvutils::HexFloat> f(0); + HexFloat> f(0); EXPECT_FALSE(ParseNumber(nullptr, &f)); EXPECT_TRUE(ParseNumber("-0.0", &f)); @@ -967,4 +965,6 @@ TEST(ParseAndEncodeNumber, Sample) { EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); } -} // anonymous namespace +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/preserve_numeric_ids_test.cpp b/3rdparty/spirv-tools/test/preserve_numeric_ids_test.cpp index 75d730d17..1c3354d55 100644 --- a/3rdparty/spirv-tools/test/preserve_numeric_ids_test.cpp +++ b/3rdparty/spirv-tools/test/preserve_numeric_ids_test.cpp @@ -18,8 +18,9 @@ #include "source/text.h" #include "source/text_handler.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { namespace { using spvtest::ScopedContext; @@ -155,3 +156,4 @@ OpFunctionEnd } } // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/software_version_test.cpp b/3rdparty/spirv-tools/test/software_version_test.cpp index e1b350440..80b944a30 100644 --- a/3rdparty/spirv-tools/test/software_version_test.cpp +++ b/3rdparty/spirv-tools/test/software_version_test.cpp @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" - #include +#include #include "gmock/gmock.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { using ::testing::AnyOf; using ::testing::Eq; using ::testing::Ge; using ::testing::StartsWith; -namespace { - void CheckFormOfHighLevelVersion(const std::string& version) { std::istringstream s(version); char v = 'x'; @@ -62,4 +63,5 @@ TEST(SoftwareVersion, DetailedIsCorrectForm) { // We don't actually care about what comes after the version number. } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/stats/CMakeLists.txt b/3rdparty/spirv-tools/test/stats/CMakeLists.txt index 20f05fd0f..3e4a0742f 100644 --- a/3rdparty/spirv-tools/test/stats/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/stats/CMakeLists.txt @@ -19,6 +19,7 @@ set(VAL_TEST_COMMON_SRCS add_spvtools_unittest(TARGET stats_aggregate SRCS stats_aggregate_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/spirv_stats.cpp ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} ) diff --git a/3rdparty/spirv-tools/test/stats/stats_aggregate_test.cpp b/3rdparty/spirv-tools/test/stats/stats_aggregate_test.cpp index 11c71aa6c..505fe2d6d 100644 --- a/3rdparty/spirv-tools/test/stats/stats_aggregate_test.cpp +++ b/3rdparty/spirv-tools/test/stats/stats_aggregate_test.cpp @@ -15,15 +15,16 @@ // Tests for unique type declaration rules validator. #include +#include -#include "source/spirv_stats.h" -#include "test_fixture.h" -#include "unit_spirv.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "tools/stats/spirv_stats.h" +namespace spvtools { +namespace stats { namespace { -using libspirv::SetContextMessageConsumer; -using libspirv::SpirvStats; using spvtest::ScopedContext; void DiagnosticsMessageHandler(spv_message_level_t level, const char*, @@ -47,7 +48,7 @@ void DiagnosticsMessageHandler(spv_message_level_t level, const char*, } } -// Calls libspirv::AggregateStats for binary compiled from |code|. +// Calls AggregateStats for binary compiled from |code|. void CompileAndAggregateStats(const std::string& code, SpirvStats* stats, spv_target_env env = SPV_ENV_UNIVERSAL_1_1) { ScopedContext ctx(env); @@ -432,53 +433,6 @@ OpMemoryModel Logical GLSL450 EXPECT_EQ(1u, stats.s64_constant_hist.at(-64)); } -TEST(AggregateStats, IdDescriptor) { - const std::string code1 = R"( -OpCapability Addresses -OpCapability Kernel -OpCapability GenericPointer -OpCapability Linkage -OpMemoryModel Physical32 OpenCL -%u32 = OpTypeInt 32 0 -%f32 = OpTypeFloat 32 -%1 = OpConstant %f32 1 -%2 = OpConstant %f32 1 -%3 = OpConstant %u32 32 -)"; - - const std::string code2 = R"( -OpCapability Shader -OpCapability Linkage -OpMemoryModel Logical GLSL450 -%f32 = OpTypeFloat 32 -%u32 = OpTypeInt 32 0 -%1 = OpConstant %f32 1 -%2 = OpConstant %f32 3 -%3 = OpConstant %u32 32 -)"; - - const uint32_t kF32 = 1951208733; - const uint32_t kU32 = 2430404313; - const uint32_t kF32_1 = 296981500; - const uint32_t kF32_3 = 1450415100; - const uint32_t kU32_32 = 827246872; - - SpirvStats stats; - - CompileAndAggregateStats(code1, &stats); - - { - const std::unordered_map expected = { - {kF32, 3}, {kU32, 2}, {kF32_1, 2}, {kU32_32, 1}}; - EXPECT_EQ(expected, stats.id_descriptor_hist); - } - - CompileAndAggregateStats(code2, &stats); - { - const std::unordered_map expected = { - {kF32, 6}, {kU32, 4}, {kF32_1, 3}, {kF32_3, 1}, {kU32_32, 2}}; - EXPECT_EQ(expected, stats.id_descriptor_hist); - } -} - } // namespace +} // namespace stats +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/stats/stats_analyzer_test.cpp b/3rdparty/spirv-tools/test/stats/stats_analyzer_test.cpp index 9608af8db..3764c5bdd 100644 --- a/3rdparty/spirv-tools/test/stats/stats_analyzer_test.cpp +++ b/3rdparty/spirv-tools/test/stats/stats_analyzer_test.cpp @@ -17,14 +17,14 @@ #include #include -#include "latest_version_spirv_header.h" -#include "test_fixture.h" +#include "source/latest_version_spirv_header.h" +#include "test/test_fixture.h" #include "tools/stats/stats_analyzer.h" +namespace spvtools { +namespace stats { namespace { -using libspirv::SpirvStats; - // Fills |stats| with some synthetic header stats, as if aggregated from 100 // modules (100 used for simpler percentage evaluation). void FillDefaultStats(SpirvStats* stats) { @@ -170,3 +170,5 @@ TEST(StatsAnalyzer, OpcodeMarkov) { } } // namespace +} // namespace stats +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/string_utils_test.cpp b/3rdparty/spirv-tools/test/string_utils_test.cpp index ea1bfba62..58514158f 100644 --- a/3rdparty/spirv-tools/test/string_utils_test.cpp +++ b/3rdparty/spirv-tools/test/string_utils_test.cpp @@ -18,11 +18,10 @@ #include "source/util/string_utils.h" #include "spirv-tools/libspirv.h" +namespace spvtools { +namespace utils { namespace { -using ::spvutils::CardinalToOrdinal; -using ::spvutils::ToString; - TEST(ToString, Int) { EXPECT_EQ("0", ToString(0)); EXPECT_EQ("1000", ToString(1000)); @@ -187,4 +186,6 @@ TEST(CardinalToOrdinal, Test) { EXPECT_EQ("1225th", CardinalToOrdinal(1225)); } -} // anonymous namespace +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/target_env_test.cpp b/3rdparty/spirv-tools/test/target_env_test.cpp index 0c7062158..f9624646d 100644 --- a/3rdparty/spirv-tools/test/target_env_test.cpp +++ b/3rdparty/spirv-tools/test/target_env_test.cpp @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - -#include "unit_spirv.h" +#include +#include "gmock/gmock.h" #include "source/spirv_target_env.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using ::testing::AnyOf; @@ -91,6 +92,7 @@ INSTANTIATE_TEST_CASE_P( {"opencl2.0embedded", true, SPV_ENV_OPENCL_EMBEDDED_2_0}, {"opencl2.1embedded", true, SPV_ENV_OPENCL_EMBEDDED_2_1}, {"opencl2.2embedded", true, SPV_ENV_OPENCL_EMBEDDED_2_2}, + {"webgpu0", true, SPV_ENV_WEBGPU_0}, {"opencl2.3", false, SPV_ENV_UNIVERSAL_1_0}, {"opencl3.0", false, SPV_ENV_UNIVERSAL_1_0}, {"vulkan1.2", false, SPV_ENV_UNIVERSAL_1_0}, @@ -100,4 +102,5 @@ INSTANTIATE_TEST_CASE_P( {"abc", false, SPV_ENV_UNIVERSAL_1_0}, })); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/test_fixture.h b/3rdparty/spirv-tools/test/test_fixture.h index e2948a899..e85015c94 100644 --- a/3rdparty/spirv-tools/test/test_fixture.h +++ b/3rdparty/spirv-tools/test/test_fixture.h @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_TEST_FIXTURE_H_ -#define LIBSPIRV_TEST_TEST_FIXTURE_H_ +#ifndef TEST_TEST_FIXTURE_H_ +#define TEST_TEST_FIXTURE_H_ -#include "unit_spirv.h" +#include +#include + +#include "test/unit_spirv.h" namespace spvtest { @@ -179,4 +182,4 @@ using TextToBinaryTest = TextToBinaryTestBase<::testing::Test>; using RoundTripTest = spvtest::TextToBinaryTestBase<::testing::TestWithParam>; -#endif // LIBSPIRV_TEST_TEST_FIXTURE_H_ +#endif // TEST_TEST_FIXTURE_H_ diff --git a/3rdparty/spirv-tools/test/text_advance_test.cpp b/3rdparty/spirv-tools/test/text_advance_test.cpp index ffc8dbb21..9de77a836 100644 --- a/3rdparty/spirv-tools/test/text_advance_test.cpp +++ b/3rdparty/spirv-tools/test/text_advance_test.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include "test/unit_spirv.h" + +namespace spvtools { namespace { -using libspirv::AssemblyContext; using spvtest::AutoText; TEST(TextAdvance, LeadingNewLines) { @@ -128,4 +130,5 @@ TEST(TextAdvance, SkipOverCRLFs) { EXPECT_EQ(2u, pos.line); EXPECT_EQ(4u, pos.index); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_destroy_test.cpp b/3rdparty/spirv-tools/test/text_destroy_test.cpp index c956d72c6..4c2837ba6 100644 --- a/3rdparty/spirv-tools/test/text_destroy_test.cpp +++ b/3rdparty/spirv-tools/test/text_destroy_test.cpp @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { TEST(TextDestroy, DestroyNull) { spvBinaryDestroy(nullptr); } @@ -70,4 +71,5 @@ TEST(TextDestroy, Default) { spvContextDestroy(context); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_literal_test.cpp b/3rdparty/spirv-tools/test/text_literal_test.cpp index dc75a8e11..702808931 100644 --- a/3rdparty/spirv-tools/test/text_literal_test.cpp +++ b/3rdparty/spirv-tools/test/text_literal_test.cpp @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include +#include #include "gmock/gmock.h" -#include "message.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" -#include +namespace spvtools { +namespace { using ::testing::Eq; -namespace { TEST(TextLiteral, GoodI32) { spv_literal_t l; @@ -166,31 +168,30 @@ using IntegerTest = spvtest::TextToBinaryTestBase<::testing::TestWithParam>; std::vector successfulEncode(const TextLiteralCase& test, - libspirv::IdTypeClass type) { + IdTypeClass type) { spv_instruction_t inst; std::string message; auto capture_message = [&message](spv_message_level_t, const char*, const spv_position_t&, const char* m) { message = m; }; - libspirv::IdType expected_type{test.bitwidth, test.is_signed, type}; + IdType expected_type{test.bitwidth, test.is_signed, type}; EXPECT_EQ(SPV_SUCCESS, - libspirv::AssemblyContext(nullptr, capture_message) + AssemblyContext(nullptr, capture_message) .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, expected_type, &inst)) << message; return inst.words; } -std::string failedEncode(const TextLiteralCase& test, - libspirv::IdTypeClass type) { +std::string failedEncode(const TextLiteralCase& test, IdTypeClass type) { spv_instruction_t inst; std::string message; auto capture_message = [&message](spv_message_level_t, const char*, const spv_position_t&, const char* m) { message = m; }; - libspirv::IdType expected_type{test.bitwidth, test.is_signed, type}; + IdType expected_type{test.bitwidth, test.is_signed, type}; EXPECT_EQ(SPV_ERROR_INVALID_TEXT, - libspirv::AssemblyContext(nullptr, capture_message) + AssemblyContext(nullptr, capture_message) .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, expected_type, &inst)); return message; @@ -198,17 +199,15 @@ std::string failedEncode(const TextLiteralCase& test, TEST_P(IntegerTest, IntegerBounds) { if (GetParam().success) { - EXPECT_THAT( - successfulEncode(GetParam(), libspirv::IdTypeClass::kScalarIntegerType), - Eq(GetParam().expected_values)); + EXPECT_THAT(successfulEncode(GetParam(), IdTypeClass::kScalarIntegerType), + Eq(GetParam().expected_values)); } else { std::stringstream ss; ss << "Integer " << GetParam().text << " does not fit in a " << GetParam().bitwidth << "-bit " << (GetParam().is_signed ? "signed" : "unsigned") << " integer"; - EXPECT_THAT( - failedEncode(GetParam(), libspirv::IdTypeClass::kScalarIntegerType), - Eq(ss.str())); + EXPECT_THAT(failedEncode(GetParam(), IdTypeClass::kScalarIntegerType), + Eq(ss.str())); } } @@ -286,9 +285,8 @@ using IntegerLeadingMinusTest = TEST_P(IntegerLeadingMinusTest, CantHaveLeadingMinusOnUnsigned) { EXPECT_FALSE(GetParam().success); - EXPECT_THAT( - failedEncode(GetParam(), libspirv::IdTypeClass::kScalarIntegerType), - Eq("Cannot put a negative number in an unsigned literal")); + EXPECT_THAT(failedEncode(GetParam(), IdTypeClass::kScalarIntegerType), + Eq("Cannot put a negative number in an unsigned literal")); } // clang-format off @@ -380,14 +378,14 @@ TEST(OverflowIntegerParse, Decimal) { std::string expected_message0 = "Invalid signed integer literal: " + signed_input; EXPECT_THAT(failedEncode(Make_Bad_Signed(64, signed_input.c_str()), - libspirv::IdTypeClass::kScalarIntegerType), + IdTypeClass::kScalarIntegerType), Eq(expected_message0)); std::string unsigned_input = "18446744073709551616"; std::string expected_message1 = "Invalid unsigned integer literal: " + unsigned_input; EXPECT_THAT(failedEncode(Make_Bad_Unsigned(64, unsigned_input.c_str()), - libspirv::IdTypeClass::kScalarIntegerType), + IdTypeClass::kScalarIntegerType), Eq(expected_message1)); // TODO(dneto): When the given number doesn't have a leading sign, @@ -395,7 +393,7 @@ TEST(OverflowIntegerParse, Decimal) { // asked for a signed number. This is kind of weird, but it's an // artefact of how we do the parsing. EXPECT_THAT(failedEncode(Make_Bad_Signed(64, unsigned_input.c_str()), - libspirv::IdTypeClass::kScalarIntegerType), + IdTypeClass::kScalarIntegerType), Eq(expected_message1)); } @@ -403,11 +401,12 @@ TEST(OverflowIntegerParse, Hex) { std::string input = "0x10000000000000000"; std::string expected_message = "Invalid unsigned integer literal: " + input; EXPECT_THAT(failedEncode(Make_Bad_Signed(64, input.c_str()), - libspirv::IdTypeClass::kScalarIntegerType), + IdTypeClass::kScalarIntegerType), Eq(expected_message)); EXPECT_THAT(failedEncode(Make_Bad_Unsigned(64, input.c_str()), - libspirv::IdTypeClass::kScalarIntegerType), + IdTypeClass::kScalarIntegerType), Eq(expected_message)); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_start_new_inst_test.cpp b/3rdparty/spirv-tools/test/text_start_new_inst_test.cpp index 6c3e55418..ff35ac84c 100644 --- a/3rdparty/spirv-tools/test/text_start_new_inst_test.cpp +++ b/3rdparty/spirv-tools/test/text_start_new_inst_test.cpp @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" - #include +#include "test/unit_spirv.h" + +namespace spvtools { namespace { -using libspirv::AssemblyContext; using spvtest::AutoText; TEST(TextStartsWithOp, YesAtStart) { @@ -71,4 +71,5 @@ TEST(TextStartsWithOp, NoForNearlyValueGeneration) { EXPECT_FALSE(AssemblyContext(AutoText("%foo"), nullptr).isStartOfNewInst()); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.annotation_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.annotation_test.cpp index 94cc8f07b..7aec90555 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.annotation_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.annotation_test.cpp @@ -15,22 +15,22 @@ // Assembler tests for instructions in the "Annotation" section of the // SPIR-V spec. -#include "unit_spirv.h" - #include +#include #include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::EnumCase; using spvtest::MakeInstruction; using spvtest::MakeVector; using spvtest::TextToBinaryTest; -using std::get; -using std::tuple; using ::testing::Combine; using ::testing::Eq; using ::testing::Values; @@ -38,23 +38,25 @@ using ::testing::ValuesIn; // Test OpDecorate -using OpDecorateSimpleTest = spvtest::TextToBinaryTestBase< - ::testing::TestWithParam>>>; +using OpDecorateSimpleTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam< + std::tuple>>>; TEST_P(OpDecorateSimpleTest, AnySimpleDecoration) { // This string should assemble, but should not validate. std::stringstream input; - input << "OpDecorate %1 " << get<1>(GetParam()).name(); - for (auto operand : get<1>(GetParam()).operands()) input << " " << operand; + input << "OpDecorate %1 " << std::get<1>(GetParam()).name(); + for (auto operand : std::get<1>(GetParam()).operands()) + input << " " << operand; input << std::endl; - EXPECT_THAT(CompiledInstructions(input.str(), get<0>(GetParam())), + EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())), Eq(MakeInstruction(SpvOpDecorate, - {1, uint32_t(get<1>(GetParam()).value())}, - get<1>(GetParam()).operands()))); + {1, uint32_t(std::get<1>(GetParam()).value())}, + std::get<1>(GetParam()).operands()))); // Also check disassembly. EXPECT_THAT( EncodeAndDecodeSuccessfully(input.str(), SPV_BINARY_TO_TEXT_OPTION_NONE, - get<0>(GetParam())), + std::get<0>(GetParam())), Eq(input.str())); } @@ -397,23 +399,26 @@ TEST_F(TextToBinaryTest, GroupMemberDecorateInvalidSecondTargetMemberNumber) { // Test OpMemberDecorate -using OpMemberDecorateSimpleTest = spvtest::TextToBinaryTestBase< - ::testing::TestWithParam>>>; +using OpMemberDecorateSimpleTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam< + std::tuple>>>; TEST_P(OpMemberDecorateSimpleTest, AnySimpleDecoration) { // This string should assemble, but should not validate. std::stringstream input; - input << "OpMemberDecorate %1 42 " << get<1>(GetParam()).name(); - for (auto operand : get<1>(GetParam()).operands()) input << " " << operand; + input << "OpMemberDecorate %1 42 " << std::get<1>(GetParam()).name(); + for (auto operand : std::get<1>(GetParam()).operands()) + input << " " << operand; input << std::endl; - EXPECT_THAT(CompiledInstructions(input.str(), get<0>(GetParam())), - Eq(MakeInstruction(SpvOpMemberDecorate, - {1, 42, uint32_t(get<1>(GetParam()).value())}, - get<1>(GetParam()).operands()))); + EXPECT_THAT( + CompiledInstructions(input.str(), std::get<0>(GetParam())), + Eq(MakeInstruction(SpvOpMemberDecorate, + {1, 42, uint32_t(std::get<1>(GetParam()).value())}, + std::get<1>(GetParam()).operands()))); // Also check disassembly. EXPECT_THAT( EncodeAndDecodeSuccessfully(input.str(), SPV_BINARY_TO_TEXT_OPTION_NONE, - get<0>(GetParam())), + std::get<0>(GetParam())), Eq(input.str())); } @@ -501,4 +506,5 @@ TEST_F(OpMemberDecorateSimpleTest, ExtraOperandsOnDecorationExpectingTwo) { // TODO(dneto): OpDecorationGroup // TODO(dneto): OpGroupDecorate -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.barrier_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.barrier_test.cpp index 743104298..545d26ff2 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.barrier_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.barrier_test.cpp @@ -15,11 +15,13 @@ // Assembler tests for instructions in the "Barrier Instructions" section // of the SPIR-V spec. -#include "unit_spirv.h" +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::MakeInstruction; @@ -78,10 +80,14 @@ TEST_F(OpMemoryBarrier, BadInvalidMemorySemanticsId) { using NamedMemoryBarrierTest = spvtest::TextToBinaryTest; -TEST_F(NamedMemoryBarrierTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("OpMemoryNamedBarrier %bar %scope %semantics", - SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpMemoryNamedBarrier'")); +// OpMemoryNamedBarrier is not in 1.0, but it is enabled by a capability. +// We should be able to assemble it. Validation checks are in another test +// file. +TEST_F(NamedMemoryBarrierTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("OpMemoryNamedBarrier %bar %scope %semantics", + SPV_ENV_UNIVERSAL_1_0), + ElementsAre(spvOpcodeMake(4, SpvOpMemoryNamedBarrier), _, _, _)); } TEST_F(NamedMemoryBarrierTest, ArgumentCount) { @@ -114,9 +120,10 @@ TEST_F(NamedMemoryBarrierTest, ArgumentTypes) { using TypeNamedBarrierTest = spvtest::TextToBinaryTest; -TEST_F(TypeNamedBarrierTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%t = OpTypeNamedBarrier", SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpTypeNamedBarrier'")); +TEST_F(TypeNamedBarrierTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%t = OpTypeNamedBarrier", SPV_ENV_UNIVERSAL_1_0), + ElementsAre(spvOpcodeMake(2, SpvOpTypeNamedBarrier), _)); } TEST_F(TypeNamedBarrierTest, ArgumentCount) { @@ -134,10 +141,11 @@ TEST_F(TypeNamedBarrierTest, ArgumentCount) { using NamedBarrierInitializeTest = spvtest::TextToBinaryTest; -TEST_F(NamedBarrierInitializeTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%bar = OpNamedBarrierInitialize %type %count", - SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpNamedBarrierInitialize'")); +TEST_F(NamedBarrierInitializeTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%bar = OpNamedBarrierInitialize %type %count", + SPV_ENV_UNIVERSAL_1_0), + ElementsAre(spvOpcodeMake(4, SpvOpNamedBarrierInitialize), _, _, _)); } TEST_F(NamedBarrierInitializeTest, ArgumentCount) { @@ -158,4 +166,5 @@ TEST_F(NamedBarrierInitializeTest, ArgumentCount) { "found '\"extra\"'.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.constant_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.constant_test.cpp index 9aa3977bc..1a24b528f 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.constant_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.constant_test.cpp @@ -15,14 +15,16 @@ // Assembler tests for instructions in the "Group Instrucions" section of the // SPIR-V spec. -#include "unit_spirv.h" - #include #include +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::Concatenate; @@ -574,8 +576,8 @@ INSTANTIATE_TEST_CASE_P( "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.0018p+128\n", // +nan "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.01ep+128\n", // +nan "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.fffffep+128\n", // +nan - "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1p+1024\n", //-inf - "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1p+1024\n", //+inf + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1p+1024\n", // -inf + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1p+1024\n", // +inf "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.8p+1024\n", // -nan "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.0fp+1024\n", // -nan "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.0000000000001p+1024\n", // -nan @@ -824,4 +826,5 @@ INSTANTIATE_TEST_CASE_P( // TODO(dneto): OpSpecConstantComposite // TODO(dneto): Negative tests for OpSpecConstantOp -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.control_flow_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.control_flow_test.cpp index 56bb408cf..07f110884 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.control_flow_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.control_flow_test.cpp @@ -16,23 +16,21 @@ // SPIR-V spec. #include +#include #include #include -#include "unit_spirv.h" - #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::Concatenate; using spvtest::EnumCase; using spvtest::MakeInstruction; using spvtest::TextToBinaryTest; -using std::get; -using std::ostringstream; -using std::tuple; using ::testing::Combine; using ::testing::Eq; using ::testing::TestWithParam; @@ -79,14 +77,14 @@ TEST_F(OpSelectionMergeTest, WrongSelectionControl) { // Test OpLoopMerge using OpLoopMergeTest = spvtest::TextToBinaryTestBase< - TestWithParam>>>; + TestWithParam>>>; TEST_P(OpLoopMergeTest, AnySingleLoopControlMask) { - const auto ctrl = get<1>(GetParam()); - ostringstream input; + const auto ctrl = std::get<1>(GetParam()); + std::ostringstream input; input << "OpLoopMerge %merge %continue " << ctrl.name(); for (auto num : ctrl.operands()) input << " " << num; - EXPECT_THAT(CompiledInstructions(input.str(), get<0>(GetParam())), + EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())), Eq(MakeInstruction(SpvOpLoopMerge, {1, 2, ctrl.value()}, ctrl.operands()))); } @@ -392,4 +390,5 @@ INSTANTIATE_TEST_CASE_P( // TODO(dneto): OpLifetimeStart // TODO(dneto): OpLifetimeStop -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.debug_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.debug_test.cpp index 93b71eeea..b85650e5e 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.debug_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.debug_test.cpp @@ -15,13 +15,14 @@ // Assembler tests for instructions in the "Debug" section of the // SPIR-V spec. -#include "unit_spirv.h" - #include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::MakeInstruction; @@ -208,4 +209,6 @@ TEST_P(OpModuleProcessedTest, AnyString) { INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpModuleProcessedTest, ::testing::Values("", "foo bar this and that"), ); -} // anonymous namespace + +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.device_side_enqueue_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.device_side_enqueue_test.cpp index 782961194..25c100b8e 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.device_side_enqueue_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.device_side_enqueue_test.cpp @@ -15,12 +15,16 @@ // Assembler tests for instructions in the "Device-Side Enqueue Instructions" // section of the SPIR-V spec. -#include "unit_spirv.h" +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { + using spvtest::MakeInstruction; using ::testing::Eq; @@ -104,4 +108,5 @@ TEST_F(OpKernelEnqueueBad, InvalidLastOperand) { // TODO(dneto): OpBuildNDRange // TODO(dneto): OpBuildNDRange -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.extension_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.extension_test.cpp index 690075f2b..0d8d324b8 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.extension_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.extension_test.cpp @@ -15,13 +15,17 @@ // Assembler tests for instructions in the "Extension Instruction" section // of the SPIR-V spec. -#include "unit_spirv.h" +#include +#include +#include #include "gmock/gmock.h" -#include "latest_version_glsl_std_450_header.h" -#include "latest_version_opencl_std_header.h" -#include "test_fixture.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::Concatenate; @@ -145,19 +149,19 @@ INSTANTIATE_TEST_CASE_P( MakeInstruction(SpvOpSubgroupBallotKHR, {1, 2, 3})}, {"%2 = OpSubgroupFirstInvocationKHR %1 %3\n", MakeInstruction(SpvOpSubgroupFirstInvocationKHR, {1, 2, 3})}, - {"OpDecorate %1 BuiltIn SubgroupEqMaskKHR\n", + {"OpDecorate %1 BuiltIn SubgroupEqMask\n", MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, SpvBuiltInSubgroupEqMaskKHR})}, - {"OpDecorate %1 BuiltIn SubgroupGeMaskKHR\n", + {"OpDecorate %1 BuiltIn SubgroupGeMask\n", MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, SpvBuiltInSubgroupGeMaskKHR})}, - {"OpDecorate %1 BuiltIn SubgroupGtMaskKHR\n", + {"OpDecorate %1 BuiltIn SubgroupGtMask\n", MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, SpvBuiltInSubgroupGtMaskKHR})}, - {"OpDecorate %1 BuiltIn SubgroupLeMaskKHR\n", + {"OpDecorate %1 BuiltIn SubgroupLeMask\n", MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, SpvBuiltInSubgroupLeMaskKHR})}, - {"OpDecorate %1 BuiltIn SubgroupLtMaskKHR\n", + {"OpDecorate %1 BuiltIn SubgroupLtMask\n", MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, SpvBuiltInSubgroupLtMaskKHR})}, })), ); @@ -316,6 +320,26 @@ INSTANTIATE_TEST_CASE_P( SpvBuiltInDeviceIndex})}, })), ); +// SPV_KHR_8bit_storage + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_8bit_storage, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine( + ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + {"OpCapability StorageBuffer8BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageBuffer8BitAccess})}, + {"OpCapability UniformAndStorageBuffer8BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityUniformAndStorageBuffer8BitAccess})}, + {"OpCapability StoragePushConstant8\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStoragePushConstant8})}, + })), ); + // SPV_KHR_multiview INSTANTIATE_TEST_CASE_P( @@ -523,6 +547,34 @@ INSTANTIATE_TEST_CASE_P( {1, SpvDecorationHlslCounterBufferGOOGLE, 2})}, })), ); +// SPV_NV_viewport_array2 + +INSTANTIATE_TEST_CASE_P( + SPV_NV_viewport_array2, ExtensionRoundTripTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, + SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpExtension \"SPV_NV_viewport_array2\"\n", + MakeInstruction(SpvOpExtension, + MakeVector("SPV_NV_viewport_array2"))}, + // The EXT and NV extensions have the same token number for this + // capability. + {"OpCapability ShaderViewportIndexLayerEXT\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityShaderViewportIndexLayerNV})}, + // Check the new capability's token number + {"OpCapability ShaderViewportIndexLayerEXT\n", + MakeInstruction(SpvOpCapability, {5254})}, + // Decorations + {"OpDecorate %1 ViewportRelativeNV\n", + MakeInstruction(SpvOpDecorate, + {1, SpvDecorationViewportRelativeNV})}, + {"OpDecorate %1 BuiltIn ViewportMaskNV\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInViewportMaskNV})}, + })), ); + // SPV_NV_shader_subgroup_partitioned INSTANTIATE_TEST_CASE_P( @@ -657,4 +709,5 @@ INSTANTIATE_TEST_CASE_P( MakeInstruction(SpvOpDecorate, {1, 5300})}, })), ); -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.function_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.function_test.cpp index 5f1b72a04..748461fb1 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.function_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.function_test.cpp @@ -15,11 +15,14 @@ // Assembler tests for instructions in the "Function" section of the // SPIR-V spec. -#include "unit_spirv.h" +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::EnumCase; @@ -74,4 +77,5 @@ TEST_F(OpFunctionControlTest, WrongFunctionControl) { // TODO(dneto): OpFunctionEnd // TODO(dneto): OpFunctionCall -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.group_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.group_test.cpp index 6fb570669..2f4b76d2f 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.group_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.group_test.cpp @@ -15,11 +15,14 @@ // Assembler tests for instructions in the "Group Instrucions" section of the // SPIR-V spec. -#include "unit_spirv.h" +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::EnumCase; @@ -69,4 +72,5 @@ TEST_F(GroupOperationTest, WrongGroupOperation) { // TODO(dneto): OpGroupUMax // TODO(dneto): OpGroupSMax -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.image_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.image_test.cpp index 65f0af379..c1adedf44 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.image_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.image_test.cpp @@ -15,11 +15,14 @@ // Assembler tests for instructions in the "Image Instructions" section of // the SPIR-V spec. -#include "unit_spirv.h" +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::MakeInstruction; @@ -269,4 +272,5 @@ TEST_F(OpImageSparseReadTest, InvalidCoordinateOperand) { // TODO(dneto): OpImageSparseDrefGather // TODO(dneto): OpImageSparseTexelsResident -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.literal_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.literal_test.cpp index 463558188..bcbb63e0d 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.literal_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.literal_test.cpp @@ -14,8 +14,11 @@ // Assembler tests for literal numbers and literal strings. -#include "test_fixture.h" +#include +#include "test/test_fixture.h" + +namespace spvtools { namespace { using spvtest::TextToBinaryTest; @@ -118,4 +121,5 @@ TEST_F(TextToBinaryTest, LiteralStringUTF8LongEncodings) { CompileFailure("OpName %target \"" + bad_1_arg_string + "\"\n")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.memory_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.memory_test.cpp index 2267b0626..ead08e6fd 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.memory_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.memory_test.cpp @@ -15,13 +15,15 @@ // Assembler tests for instructions in the "Memory Instructions" section of // the SPIR-V spec. -#include "unit_spirv.h" - #include +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::EnumCase; @@ -105,4 +107,5 @@ INSTANTIATE_TEST_CASE_P( // TODO(dneto): OpArrayLength // TODO(dneto): OpGenercPtrMemSemantics -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.misc_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.misc_test.cpp index 3d21b4320..03b1e0914 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.misc_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.misc_test.cpp @@ -15,11 +15,12 @@ // Assembler tests for instructions in the "Miscellaneous" section of the // SPIR-V spec. -#include "unit_spirv.h" +#include "test/unit_spirv.h" #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { namespace { using SpirvVector = spvtest::TextToBinaryTest::SpirvVector; @@ -53,4 +54,5 @@ OpXYZ EXPECT_THAT(CompileFailure(assembly), Eq("Invalid Opcode name 'OpXYZ'")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.mode_setting_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.mode_setting_test.cpp index 972a2f361..ed4fa2fb4 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.mode_setting_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.mode_setting_test.cpp @@ -15,18 +15,20 @@ // Assembler tests for instructions in the "Mode-Setting" section of the // SPIR-V spec. -#include "unit_spirv.h" +#include +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::EnumCase; using spvtest::MakeInstruction; using spvtest::MakeVector; -using std::get; -using std::tuple; using ::testing::Combine; using ::testing::Eq; using ::testing::TestWithParam; @@ -134,17 +136,18 @@ TEST_F(OpEntryPointTest, WrongModel) { // Test OpExecutionMode using OpExecutionModeTest = spvtest::TextToBinaryTestBase< - TestWithParam>>>; + TestWithParam>>>; TEST_P(OpExecutionModeTest, AnyExecutionMode) { // This string should assemble, but should not validate. std::stringstream input; - input << "OpExecutionMode %1 " << get<1>(GetParam()).name(); - for (auto operand : get<1>(GetParam()).operands()) input << " " << operand; - EXPECT_THAT( - CompiledInstructions(input.str(), get<0>(GetParam())), - Eq(MakeInstruction(SpvOpExecutionMode, {1, get<1>(GetParam()).value()}, - get<1>(GetParam()).operands()))); + input << "OpExecutionMode %1 " << std::get<1>(GetParam()).name(); + for (auto operand : std::get<1>(GetParam()).operands()) + input << " " << operand; + EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())), + Eq(MakeInstruction(SpvOpExecutionMode, + {1, std::get<1>(GetParam()).value()}, + std::get<1>(GetParam()).operands()))); } #define CASE(NAME) SpvExecutionMode##NAME, #NAME @@ -295,4 +298,5 @@ TEST_F(TextToBinaryCapability, BadInvalidCapability) { // TODO(dneto): OpExecutionMode -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.pipe_storage_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.pipe_storage_test.cpp index 97dc286f5..f74dbcfdf 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.pipe_storage_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.pipe_storage_test.cpp @@ -13,8 +13,9 @@ // limitations under the License. #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { namespace { using ::spvtest::MakeInstruction; @@ -22,9 +23,12 @@ using ::testing::Eq; using OpTypePipeStorageTest = spvtest::TextToBinaryTest; -TEST_F(OpTypePipeStorageTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%res = OpTypePipeStorage", SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpTypePipeStorage'")); +// It can assemble, but should not validate. Validation checks for version +// and capability are in another test file. +TEST_F(OpTypePipeStorageTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%res = OpTypePipeStorage", SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpTypePipeStorage, {1}))); } TEST_F(OpTypePipeStorageTest, ArgumentCount) { @@ -42,10 +46,10 @@ TEST_F(OpTypePipeStorageTest, ArgumentCount) { using OpConstantPipeStorageTest = spvtest::TextToBinaryTest; -TEST_F(OpConstantPipeStorageTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage %2 3 4 5", - SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpConstantPipeStorage'")); +TEST_F(OpConstantPipeStorageTest, OpcodeAssemblesInV10) { + EXPECT_THAT(CompiledInstructions("%1 = OpConstantPipeStorage %2 3 4 5", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpConstantPipeStorage, {1, 2, 3, 4, 5}))); } TEST_F(OpConstantPipeStorageTest, ArgumentCount) { @@ -84,10 +88,10 @@ TEST_F(OpConstantPipeStorageTest, ArgumentTypes) { using OpCreatePipeFromPipeStorageTest = spvtest::TextToBinaryTest; -TEST_F(OpCreatePipeFromPipeStorageTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%1 = OpCreatePipeFromPipeStorage %2 %3", - SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpCreatePipeFromPipeStorage'")); +TEST_F(OpCreatePipeFromPipeStorageTest, OpcodeAssemblesInV10) { + EXPECT_THAT(CompiledInstructions("%1 = OpCreatePipeFromPipeStorage %2 %3", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpCreatePipeFromPipeStorage, {1, 2, 3}))); } TEST_F(OpCreatePipeFromPipeStorageTest, ArgumentCount) { @@ -118,4 +122,5 @@ TEST_F(OpCreatePipeFromPipeStorageTest, ArgumentTypes) { Eq("Expected id to start with %.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.reserved_sampling_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.reserved_sampling_test.cpp index db9a53dc6..42e4e2aeb 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.reserved_sampling_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.reserved_sampling_test.cpp @@ -14,44 +14,50 @@ // Validation tests for illegal instructions -#include "unit_spirv.h" +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { +using ::spvtest::MakeInstruction; using ::testing::Eq; -using ReservedSamplingInstTest = spvtest::TextToBinaryTest; +using ReservedSamplingInstTest = RoundTripTest; TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjImplicitLod) { - const std::string input = "OpImageSparseSampleProjImplicitLod %1 %2 %3\n"; - EXPECT_THAT(CompileFailure(input), - Eq("Invalid Opcode name 'OpImageSparseSampleProjImplicitLod'")); + std::string input = "%2 = OpImageSparseSampleProjImplicitLod %1 %3 %4\n"; + EXPECT_THAT( + CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjImplicitLod, {1, 2, 3, 4}))); } TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjExplicitLod) { - const std::string input = - "OpImageSparseSampleProjExplicitLod %1 %2 %3 Lod %4\n"; - EXPECT_THAT(CompileFailure(input), - Eq("Invalid Opcode name 'OpImageSparseSampleProjExplicitLod'")); + std::string input = + "%2 = OpImageSparseSampleProjExplicitLod %1 %3 %4 Lod %5\n"; + EXPECT_THAT(CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjExplicitLod, + {1, 2, 3, 4, SpvImageOperandsLodMask, 5}))); } TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefImplicitLod) { - const std::string input = - "OpImageSparseSampleProjDrefImplicitLod %1 %2 %3 %4\n"; - EXPECT_THAT( - CompileFailure(input), - Eq("Invalid Opcode name 'OpImageSparseSampleProjDrefImplicitLod'")); + std::string input = + "%2 = OpImageSparseSampleProjDrefImplicitLod %1 %3 %4 %5\n"; + EXPECT_THAT(CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjDrefImplicitLod, + {1, 2, 3, 4, 5}))); } TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefExplicitLod) { - const std::string input = - "OpImageSparseSampleProjDrefExplicitLod %1 %2 %3 %4 Lod %5\n"; - EXPECT_THAT( - CompileFailure(input), - Eq("Invalid Opcode name 'OpImageSparseSampleProjDrefExplicitLod'")); + std::string input = + "%2 = OpImageSparseSampleProjDrefExplicitLod %1 %3 %4 %5 Lod %6\n"; + EXPECT_THAT(CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjDrefExplicitLod, + {1, 2, 3, 4, 5, SpvImageOperandsLodMask, 6}))); } } // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.subgroup_dispatch_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.subgroup_dispatch_test.cpp index 8cc8896ff..967e3c38b 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.subgroup_dispatch_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.subgroup_dispatch_test.cpp @@ -15,25 +15,28 @@ // Assembler tests for instructions in the "Barrier Instructions" section // of the SPIR-V spec. -#include "unit_spirv.h" +#include "test/unit_spirv.h" #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { namespace { using ::spvtest::MakeInstruction; -using std::vector; using ::testing::Eq; using OpGetKernelLocalSizeForSubgroupCountTest = spvtest::TextToBinaryTest; -TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, OpcodeUnrecognizedInV10) { +// We should be able to assemble it. Validation checks are in another test +// file. +TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, OpcodeAssemblesInV10) { EXPECT_THAT( - CompileFailure("%res = OpGetKernelLocalSizeForSubgroupCount %type " - "%sgcount %invoke %param %param_size %param_align", - SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpGetKernelLocalSizeForSubgroupCount'")); + CompiledInstructions("%res = OpGetKernelLocalSizeForSubgroupCount %type " + "%sgcount %invoke %param %param_size %param_align", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpGetKernelLocalSizeForSubgroupCount, + {1, 2, 3, 4, 5, 6, 7}))); } TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, ArgumentCount) { @@ -75,11 +78,12 @@ TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, ArgumentTypes) { using OpGetKernelMaxNumSubgroupsTest = spvtest::TextToBinaryTest; -TEST_F(OpGetKernelMaxNumSubgroupsTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%res = OpGetKernelLocalSizeForSubgroupCount " - "%type %invoke %param %param_size %param_align", - SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpGetKernelLocalSizeForSubgroupCount'")); +TEST_F(OpGetKernelMaxNumSubgroupsTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%res = OpGetKernelMaxNumSubgroups %type " + "%invoke %param %param_size %param_align", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpGetKernelMaxNumSubgroups, {1, 2, 3, 4, 5, 6}))); } TEST_F(OpGetKernelMaxNumSubgroupsTest, ArgumentCount) { @@ -114,4 +118,5 @@ TEST_F(OpGetKernelMaxNumSubgroupsTest, ArgumentTypes) { Eq("Expected id to start with %.")); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary.type_declaration_test.cpp b/3rdparty/spirv-tools/test/text_to_binary.type_declaration_test.cpp index be8d6916e..c6f158f29 100644 --- a/3rdparty/spirv-tools/test/text_to_binary.type_declaration_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary.type_declaration_test.cpp @@ -15,11 +15,14 @@ // Assembler tests for instructions in the "Type-Declaration" section of the // SPIR-V spec. -#include "unit_spirv.h" +#include +#include #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { using spvtest::EnumCase; @@ -235,9 +238,12 @@ TEST_F(OpTypeForwardPointerTest, WrongClass) { using OpSizeOfTest = spvtest::TextToBinaryTest; -TEST_F(OpSizeOfTest, OpcodeUnrecognizedInV10) { - EXPECT_THAT(CompileFailure("%1 = OpSizeOf %2 %3", SPV_ENV_UNIVERSAL_1_0), - Eq("Invalid Opcode name 'OpSizeOf'")); +// We should be able to assemble it. Validation checks are in another test +// file. +TEST_F(OpSizeOfTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%1 = OpSizeOf %2 %3", SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpSizeOf, {1, 2, 3}))); } TEST_F(OpSizeOfTest, ArgumentCount) { @@ -283,4 +289,5 @@ TEST_F(OpSizeOfTest, ArgumentTypes) { // TODO(dneto): OpTypeReserveId // TODO(dneto): OpTypeQueue -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_to_binary_test.cpp b/3rdparty/spirv-tools/test/text_to_binary_test.cpp index e9a709b6f..4ba37ad4d 100644 --- a/3rdparty/spirv-tools/test/text_to_binary_test.cpp +++ b/3rdparty/spirv-tools/test/text_to_binary_test.cpp @@ -14,21 +14,20 @@ #include #include +#include #include #include #include "gmock/gmock.h" - #include "source/spirv_constant.h" #include "source/util/bitutils.h" #include "source/util/hex_float.h" -#include "test_fixture.h" -#include "unit_spirv.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +namespace spvtools { namespace { -using libspirv::AssemblyContext; -using libspirv::AssemblyGrammar; using spvtest::AutoText; using spvtest::Concatenate; using spvtest::MakeInstruction; @@ -266,4 +265,5 @@ TEST(CreateContext, VulkanEnvironment) { spvContextDestroy(c); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/text_word_get_test.cpp b/3rdparty/spirv-tools/test/text_word_get_test.cpp index dccc7133a..b74a680fa 100644 --- a/3rdparty/spirv-tools/test/text_word_get_test.cpp +++ b/3rdparty/spirv-tools/test/text_word_get_test.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include +#include "test/unit_spirv.h" + +namespace spvtools { namespace { -using libspirv::AssemblyContext; using spvtest::AutoText; #define TAB "\t" @@ -248,4 +250,5 @@ TEST(TextWordGet, CRLF) { EXPECT_STREQ("d", word.c_str()); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/timer_test.cpp b/3rdparty/spirv-tools/test/timer_test.cpp index 0002ae626..e53af6653 100644 --- a/3rdparty/spirv-tools/test/timer_test.cpp +++ b/3rdparty/spirv-tools/test/timer_test.cpp @@ -18,13 +18,10 @@ #include "gtest/gtest.h" #include "source/util/timer.h" +namespace spvtools { +namespace utils { namespace { -using ::spvutils::CumulativeTimer; -using ::spvutils::PrintTimerDescription; -using ::spvutils::ScopedTimer; -using ::spvutils::Timer; - // A mock class to mimic Timer class for a testing purpose. It has fixed // CPU/WALL/USR/SYS time, RSS delta, and the delta of the number of page faults. class MockTimer : public Timer { @@ -140,4 +137,6 @@ TEST(MockCumulativeTimer, DoNothing) { if (ctimer) delete ctimer; } -} // anonymous namespace +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/tools/CMakeLists.txt b/3rdparty/spirv-tools/test/tools/CMakeLists.txt new file mode 100644 index 000000000..cee95cadb --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +spirv_add_nosetests(expect) +spirv_add_nosetests(spirv_test_framework) + +add_subdirectory(opt) diff --git a/3rdparty/spirv-tools/test/tools/expect.py b/3rdparty/spirv-tools/test/tools/expect.py new file mode 100755 index 000000000..c9596506a --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/expect.py @@ -0,0 +1,677 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A number of common spirv result checks coded in mixin classes. + +A test case can use these checks by declaring their enclosing mixin classes +as superclass and providing the expected_* variables required by the check_*() +methods in the mixin classes. +""" +import difflib +import os +import re +import subprocess +from spirv_test_framework import SpirvTest + + +def convert_to_unix_line_endings(source): + """Converts all line endings in source to be unix line endings.""" + return source.replace('\r\n', '\n').replace('\r', '\n') + + +def substitute_file_extension(filename, extension): + """Substitutes file extension, respecting known shader extensions. + + foo.vert -> foo.vert.[extension] [similarly for .frag, .comp, etc.] + foo.glsl -> foo.[extension] + foo.unknown -> foo.[extension] + foo -> foo.[extension] + """ + if filename[-5:] not in [ + '.vert', '.frag', '.tesc', '.tese', '.geom', '.comp', '.spvasm' + ]: + return filename.rsplit('.', 1)[0] + '.' + extension + else: + return filename + '.' + extension + + +def get_object_filename(source_filename): + """Gets the object filename for the given source file.""" + return substitute_file_extension(source_filename, 'spv') + + +def get_assembly_filename(source_filename): + """Gets the assembly filename for the given source file.""" + return substitute_file_extension(source_filename, 'spvasm') + + +def verify_file_non_empty(filename): + """Checks that a given file exists and is not empty.""" + if not os.path.isfile(filename): + return False, 'Cannot find file: ' + filename + if not os.path.getsize(filename): + return False, 'Empty file: ' + filename + return True, '' + + +class ReturnCodeIsZero(SpirvTest): + """Mixin class for checking that the return code is zero.""" + + def check_return_code_is_zero(self, status): + if status.returncode: + return False, 'Non-zero return code: {ret}\n'.format( + ret=status.returncode) + return True, '' + + +class NoOutputOnStdout(SpirvTest): + """Mixin class for checking that there is no output on stdout.""" + + def check_no_output_on_stdout(self, status): + if status.stdout: + return False, 'Non empty stdout: {out}\n'.format(out=status.stdout) + return True, '' + + +class NoOutputOnStderr(SpirvTest): + """Mixin class for checking that there is no output on stderr.""" + + def check_no_output_on_stderr(self, status): + if status.stderr: + return False, 'Non empty stderr: {err}\n'.format(err=status.stderr) + return True, '' + + +class SuccessfulReturn(ReturnCodeIsZero, NoOutputOnStdout, NoOutputOnStderr): + """Mixin class for checking that return code is zero and no output on + stdout and stderr.""" + pass + + +class NoGeneratedFiles(SpirvTest): + """Mixin class for checking that there is no file generated.""" + + def check_no_generated_files(self, status): + all_files = os.listdir(status.directory) + input_files = status.input_filenames + if all([f.startswith(status.directory) for f in input_files]): + all_files = [os.path.join(status.directory, f) for f in all_files] + generated_files = set(all_files) - set(input_files) + if len(generated_files) == 0: + return True, '' + else: + return False, 'Extra files generated: {}'.format(generated_files) + + +class CorrectBinaryLengthAndPreamble(SpirvTest): + """Provides methods for verifying preamble for a SPIR-V binary.""" + + def verify_binary_length_and_header(self, binary, spv_version=0x10000): + """Checks that the given SPIR-V binary has valid length and header. + + Returns: + False, error string if anything is invalid + True, '' otherwise + Args: + binary: a bytes object containing the SPIR-V binary + spv_version: target SPIR-V version number, with same encoding + as the version word in a SPIR-V header. + """ + + def read_word(binary, index, little_endian): + """Reads the index-th word from the given binary file.""" + word = binary[index * 4:(index + 1) * 4] + if little_endian: + word = reversed(word) + return reduce(lambda w, b: (w << 8) | ord(b), word, 0) + + def check_endianness(binary): + """Checks the endianness of the given SPIR-V binary. + + Returns: + True if it's little endian, False if it's big endian. + None if magic number is wrong. + """ + first_word = read_word(binary, 0, True) + if first_word == 0x07230203: + return True + first_word = read_word(binary, 0, False) + if first_word == 0x07230203: + return False + return None + + num_bytes = len(binary) + if num_bytes % 4 != 0: + return False, ('Incorrect SPV binary: size should be a multiple' + ' of words') + if num_bytes < 20: + return False, 'Incorrect SPV binary: size less than 5 words' + + preamble = binary[0:19] + little_endian = check_endianness(preamble) + # SPIR-V module magic number + if little_endian is None: + return False, 'Incorrect SPV binary: wrong magic number' + + # SPIR-V version number + version = read_word(preamble, 1, little_endian) + # TODO(dneto): Recent Glslang uses version word 0 for opengl_compat + # profile + + if version != spv_version and version != 0: + return False, 'Incorrect SPV binary: wrong version number' + # Shaderc-over-Glslang (0x000d....) or + # SPIRV-Tools (0x0007....) generator number + if read_word(preamble, 2, little_endian) != 0x000d0007 and \ + read_word(preamble, 2, little_endian) != 0x00070000: + return False, ('Incorrect SPV binary: wrong generator magic ' 'number') + # reserved for instruction schema + if read_word(preamble, 4, little_endian) != 0: + return False, 'Incorrect SPV binary: the 5th byte should be 0' + + return True, '' + + +class CorrectObjectFilePreamble(CorrectBinaryLengthAndPreamble): + """Provides methods for verifying preamble for a SPV object file.""" + + def verify_object_file_preamble(self, filename, spv_version=0x10000): + """Checks that the given SPIR-V binary file has correct preamble.""" + + success, message = verify_file_non_empty(filename) + if not success: + return False, message + + with open(filename, 'rb') as object_file: + object_file.seek(0, os.SEEK_END) + num_bytes = object_file.tell() + + object_file.seek(0) + + binary = bytes(object_file.read()) + return self.verify_binary_length_and_header(binary, spv_version) + + return True, '' + + +class CorrectAssemblyFilePreamble(SpirvTest): + """Provides methods for verifying preamble for a SPV assembly file.""" + + def verify_assembly_file_preamble(self, filename): + success, message = verify_file_non_empty(filename) + if not success: + return False, message + + with open(filename) as assembly_file: + line1 = assembly_file.readline() + line2 = assembly_file.readline() + line3 = assembly_file.readline() + + if (line1 != '; SPIR-V\n' or line2 != '; Version: 1.0\n' or + (not line3.startswith('; Generator: Google Shaderc over Glslang;'))): + return False, 'Incorrect SPV assembly' + + return True, '' + + +class ValidObjectFile(SuccessfulReturn, CorrectObjectFilePreamble): + """Mixin class for checking that every input file generates a valid SPIR-V 1.0 + object file following the object file naming rule, and there is no output on + stdout/stderr.""" + + def check_object_file_preamble(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename)) + if not success: + return False, message + return True, '' + + +class ValidObjectFile1_3(ReturnCodeIsZero, CorrectObjectFilePreamble): + """Mixin class for checking that every input file generates a valid SPIR-V 1.3 + object file following the object file naming rule, and there is no output on + stdout/stderr.""" + + def check_object_file_preamble(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename), 0x10300) + if not success: + return False, message + return True, '' + + +class ValidObjectFileWithAssemblySubstr(SuccessfulReturn, + CorrectObjectFilePreamble): + """Mixin class for checking that every input file generates a valid object + + file following the object file naming rule, there is no output on + stdout/stderr, and the disassmbly contains a specified substring per + input. + """ + + def check_object_file_disassembly(self, status): + for an_input in status.inputs: + object_filename = get_object_filename(an_input.filename) + obj_file = str(os.path.join(status.directory, object_filename)) + success, message = self.verify_object_file_preamble(obj_file) + if not success: + return False, message + cmd = [status.test_manager.disassembler_path, '--no-color', obj_file] + process = subprocess.Popen( + args=cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=status.directory) + output = process.communicate(None) + disassembly = output[0] + if not isinstance(an_input.assembly_substr, str): + return False, 'Missing assembly_substr member' + if an_input.assembly_substr not in disassembly: + return False, ('Incorrect disassembly output:\n{asm}\n' + 'Expected substring not found:\n{exp}'.format( + asm=disassembly, exp=an_input.assembly_substr)) + return True, '' + + +class ValidNamedObjectFile(SuccessfulReturn, CorrectObjectFilePreamble): + """Mixin class for checking that a list of object files with the given + names are correctly generated, and there is no output on stdout/stderr. + + To mix in this class, subclasses need to provide expected_object_filenames + as the expected object filenames. + """ + + def check_object_file_preamble(self, status): + for object_filename in self.expected_object_filenames: + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename)) + if not success: + return False, message + return True, '' + + +class ValidFileContents(SpirvTest): + """Mixin class to test that a specific file contains specific text + To mix in this class, subclasses need to provide expected_file_contents as + the contents of the file and target_filename to determine the location.""" + + def check_file(self, status): + target_filename = os.path.join(status.directory, self.target_filename) + if not os.path.isfile(target_filename): + return False, 'Cannot find file: ' + target_filename + with open(target_filename, 'r') as target_file: + file_contents = target_file.read() + if isinstance(self.expected_file_contents, str): + if file_contents == self.expected_file_contents: + return True, '' + return False, ('Incorrect file output: \n{act}\n' + 'Expected:\n{exp}' + 'With diff:\n{diff}'.format( + act=file_contents, + exp=self.expected_file_contents, + diff='\n'.join( + list( + difflib.unified_diff( + self.expected_file_contents.split('\n'), + file_contents.split('\n'), + fromfile='expected_output', + tofile='actual_output'))))) + elif isinstance(self.expected_file_contents, type(re.compile(''))): + if self.expected_file_contents.search(file_contents): + return True, '' + return False, ('Incorrect file output: \n{act}\n' + 'Expected matching regex pattern:\n{exp}'.format( + act=file_contents, + exp=self.expected_file_contents.pattern)) + return False, ( + 'Could not open target file ' + target_filename + ' for reading') + + +class ValidAssemblyFile(SuccessfulReturn, CorrectAssemblyFilePreamble): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, and there is no output on + stdout/stderr.""" + + def check_assembly_file_preamble(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + return True, '' + + +class ValidAssemblyFileWithSubstr(ValidAssemblyFile): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, there is no output on + stdout/stderr, and all assembly files have the given substring specified + by expected_assembly_substr. + + To mix in this class, subclasses need to provde expected_assembly_substr + as the expected substring. + """ + + def check_assembly_with_substr(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + with open(assembly_filename, 'r') as f: + content = f.read() + if self.expected_assembly_substr not in convert_to_unix_line_endings( + content): + return False, ('Incorrect assembly output:\n{asm}\n' + 'Expected substring not found:\n{exp}'.format( + asm=content, exp=self.expected_assembly_substr)) + return True, '' + + +class ValidAssemblyFileWithoutSubstr(ValidAssemblyFile): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, there is no output on + stdout/stderr, and no assembly files have the given substring specified + by unexpected_assembly_substr. + + To mix in this class, subclasses need to provde unexpected_assembly_substr + as the substring we expect not to see. + """ + + def check_assembly_for_substr(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + with open(assembly_filename, 'r') as f: + content = f.read() + if self.unexpected_assembly_substr in convert_to_unix_line_endings( + content): + return False, ('Incorrect assembly output:\n{asm}\n' + 'Unexpected substring found:\n{unexp}'.format( + asm=content, exp=self.unexpected_assembly_substr)) + return True, '' + + +class ValidNamedAssemblyFile(SuccessfulReturn, CorrectAssemblyFilePreamble): + """Mixin class for checking that a list of assembly files with the given + names are correctly generated, and there is no output on stdout/stderr. + + To mix in this class, subclasses need to provide expected_assembly_filenames + as the expected assembly filenames. + """ + + def check_object_file_preamble(self, status): + for assembly_filename in self.expected_assembly_filenames: + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + return True, '' + + +class ErrorMessage(SpirvTest): + """Mixin class for tests that fail with a specific error message. + + To mix in this class, subclasses need to provide expected_error as the + expected error message. + + The test should fail if the subprocess was terminated by a signal. + """ + + def check_has_error_message(self, status): + if not status.returncode: + return False, ('Expected error message, but returned success from ' + 'command execution') + if status.returncode < 0: + # On Unix, a negative value -N for Popen.returncode indicates + # termination by signal N. + # https://docs.python.org/2/library/subprocess.html + return False, ('Expected error message, but command was terminated by ' + 'signal ' + str(status.returncode)) + if not status.stderr: + return False, 'Expected error message, but no output on stderr' + if self.expected_error != convert_to_unix_line_endings(status.stderr): + return False, ('Incorrect stderr output:\n{act}\n' + 'Expected:\n{exp}'.format( + act=status.stderr, exp=self.expected_error)) + return True, '' + + +class ErrorMessageSubstr(SpirvTest): + """Mixin class for tests that fail with a specific substring in the error + message. + + To mix in this class, subclasses need to provide expected_error_substr as + the expected error message substring. + + The test should fail if the subprocess was terminated by a signal. + """ + + def check_has_error_message_as_substring(self, status): + if not status.returncode: + return False, ('Expected error message, but returned success from ' + 'command execution') + if status.returncode < 0: + # On Unix, a negative value -N for Popen.returncode indicates + # termination by signal N. + # https://docs.python.org/2/library/subprocess.html + return False, ('Expected error message, but command was terminated by ' + 'signal ' + str(status.returncode)) + if not status.stderr: + return False, 'Expected error message, but no output on stderr' + if self.expected_error_substr not in convert_to_unix_line_endings( + status.stderr): + return False, ('Incorrect stderr output:\n{act}\n' + 'Expected substring not found in stderr:\n{exp}'.format( + act=status.stderr, exp=self.expected_error_substr)) + return True, '' + + +class WarningMessage(SpirvTest): + """Mixin class for tests that succeed but have a specific warning message. + + To mix in this class, subclasses need to provide expected_warning as the + expected warning message. + """ + + def check_has_warning_message(self, status): + if status.returncode: + return False, ('Expected warning message, but returned failure from' + ' command execution') + if not status.stderr: + return False, 'Expected warning message, but no output on stderr' + if self.expected_warning != convert_to_unix_line_endings(status.stderr): + return False, ('Incorrect stderr output:\n{act}\n' + 'Expected:\n{exp}'.format( + act=status.stderr, exp=self.expected_warning)) + return True, '' + + +class ValidObjectFileWithWarning(NoOutputOnStdout, CorrectObjectFilePreamble, + WarningMessage): + """Mixin class for checking that every input file generates a valid object + file following the object file naming rule, with a specific warning message. + """ + + def check_object_file_preamble(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename)) + if not success: + return False, message + return True, '' + + +class ValidAssemblyFileWithWarning(NoOutputOnStdout, + CorrectAssemblyFilePreamble, WarningMessage): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, with a specific warning + message.""" + + def check_assembly_file_preamble(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + return True, '' + + +class StdoutMatch(SpirvTest): + """Mixin class for tests that can expect output on stdout. + + To mix in this class, subclasses need to provide expected_stdout as the + expected stdout output. + + For expected_stdout, if it's True, then they expect something on stdout but + will not check what it is. If it's a string, expect an exact match. If it's + anything else, it is assumed to be a compiled regular expression which will + be matched against re.search(). It will expect + expected_stdout.search(status.stdout) to be true. + """ + + def check_stdout_match(self, status): + # "True" in this case means we expect something on stdout, but we do not + # care what it is, we want to distinguish this from "blah" which means we + # expect exactly the string "blah". + if self.expected_stdout is True: + if not status.stdout: + return False, 'Expected something on stdout' + elif type(self.expected_stdout) == str: + if self.expected_stdout != convert_to_unix_line_endings(status.stdout): + return False, ('Incorrect stdout output:\n{ac}\n' + 'Expected:\n{ex}'.format( + ac=status.stdout, ex=self.expected_stdout)) + else: + if not self.expected_stdout.search( + convert_to_unix_line_endings(status.stdout)): + return False, ('Incorrect stdout output:\n{ac}\n' + 'Expected to match regex:\n{ex}'.format( + ac=status.stdout, ex=self.expected_stdout.pattern)) + return True, '' + + +class StderrMatch(SpirvTest): + """Mixin class for tests that can expect output on stderr. + + To mix in this class, subclasses need to provide expected_stderr as the + expected stderr output. + + For expected_stderr, if it's True, then they expect something on stderr, + but will not check what it is. If it's a string, expect an exact match. + If it's anything else, it is assumed to be a compiled regular expression + which will be matched against re.search(). It will expect + expected_stderr.search(status.stderr) to be true. + """ + + def check_stderr_match(self, status): + # "True" in this case means we expect something on stderr, but we do not + # care what it is, we want to distinguish this from "blah" which means we + # expect exactly the string "blah". + if self.expected_stderr is True: + if not status.stderr: + return False, 'Expected something on stderr' + elif type(self.expected_stderr) == str: + if self.expected_stderr != convert_to_unix_line_endings(status.stderr): + return False, ('Incorrect stderr output:\n{ac}\n' + 'Expected:\n{ex}'.format( + ac=status.stderr, ex=self.expected_stderr)) + else: + if not self.expected_stderr.search( + convert_to_unix_line_endings(status.stderr)): + return False, ('Incorrect stderr output:\n{ac}\n' + 'Expected to match regex:\n{ex}'.format( + ac=status.stderr, ex=self.expected_stderr.pattern)) + return True, '' + + +class StdoutNoWiderThan80Columns(SpirvTest): + """Mixin class for tests that require stdout to 80 characters or narrower. + + To mix in this class, subclasses need to provide expected_stdout as the + expected stdout output. + """ + + def check_stdout_not_too_wide(self, status): + if not status.stdout: + return True, '' + else: + for line in status.stdout.splitlines(): + if len(line) > 80: + return False, ('Stdout line longer than 80 columns: %s' % line) + return True, '' + + +class NoObjectFile(SpirvTest): + """Mixin class for checking that no input file has a corresponding object + file.""" + + def check_no_object_file(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + full_object_file = os.path.join(status.directory, object_filename) + print('checking %s' % full_object_file) + if os.path.isfile(full_object_file): + return False, ( + 'Expected no object file, but found: %s' % full_object_file) + return True, '' + + +class NoNamedOutputFiles(SpirvTest): + """Mixin class for checking that no specified output files exist. + + The expected_output_filenames member should be full pathnames.""" + + def check_no_named_output_files(self, status): + for object_filename in self.expected_output_filenames: + if os.path.isfile(object_filename): + return False, ( + 'Expected no output file, but found: %s' % object_filename) + return True, '' + + +class ExecutedListOfPasses(SpirvTest): + """Mixin class for checking that a list of passes where executed. + + It works by analyzing the output of the --print-all flag to spirv-opt. + + For this mixin to work, the class member expected_passes should be a sequence + of pass names as returned by Pass::name(). + """ + + def check_list_of_executed_passes(self, status): + # Collect all the output lines containing a pass name. + pass_names = [] + pass_name_re = re.compile(r'.*IR before pass (?P[\S]+)') + for line in status.stderr.splitlines(): + match = pass_name_re.match(line) + if match: + pass_names.append(match.group('pass_name')) + + for (expected, actual) in zip(self.expected_passes, pass_names): + if expected != actual: + return False, ( + 'Expected pass "%s" but found pass "%s"\n' % (expected, actual)) + + return True, '' diff --git a/3rdparty/spirv-tools/test/tools/expect_nosetest.py b/3rdparty/spirv-tools/test/tools/expect_nosetest.py new file mode 100755 index 000000000..b591a2d07 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/expect_nosetest.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the expect module.""" + +import expect +from spirv_test_framework import TestStatus +from nose.tools import assert_equal, assert_true, assert_false +import re + + +def nosetest_get_object_name(): + """Tests get_object_filename().""" + source_and_object_names = [('a.vert', 'a.vert.spv'), ('b.frag', 'b.frag.spv'), + ('c.tesc', 'c.tesc.spv'), ('d.tese', 'd.tese.spv'), + ('e.geom', 'e.geom.spv'), ('f.comp', 'f.comp.spv'), + ('file', 'file.spv'), ('file.', 'file.spv'), + ('file.uk', + 'file.spv'), ('file.vert.', + 'file.vert.spv'), ('file.vert.bla', + 'file.vert.spv')] + actual_object_names = [ + expect.get_object_filename(f[0]) for f in source_and_object_names + ] + expected_object_names = [f[1] for f in source_and_object_names] + + assert_equal(actual_object_names, expected_object_names) + + +class TestStdoutMatchADotC(expect.StdoutMatch): + expected_stdout = re.compile('a.c') + + +def nosetest_stdout_match_regex_has_match(): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout='0abc1', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + assert_true(test.check_stdout_match(status)[0]) + + +def nosetest_stdout_match_regex_no_match(): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout='ab', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + assert_false(test.check_stdout_match(status)[0]) + + +def nosetest_stdout_match_regex_empty_stdout(): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout='', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + assert_false(test.check_stdout_match(status)[0]) diff --git a/3rdparty/spirv-tools/test/tools/opt/CMakeLists.txt b/3rdparty/spirv-tools/test/tools/opt/CMakeLists.txt new file mode 100644 index 000000000..a6dc5262d --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/opt/CMakeLists.txt @@ -0,0 +1,25 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT ${SPIRV_SKIP_TESTS}) + if(${PYTHONINTERP_FOUND}) + add_test(NAME spirv_opt_tests + COMMAND ${PYTHON_EXECUTABLE} + ${CMAKE_CURRENT_SOURCE_DIR}/../spirv_test_framework.py + $ $ $ + --test-dir ${CMAKE_CURRENT_SOURCE_DIR}) + else() + message("Skipping CLI tools tests - Python executable not found") + endif() +endif() diff --git a/3rdparty/spirv-tools/test/tools/opt/flags.py b/3rdparty/spirv-tools/test/tools/opt/flags.py new file mode 100644 index 000000000..628d87108 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/opt/flags.py @@ -0,0 +1,330 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import placeholder +import expect +import re + +from spirv_test_framework import inside_spirv_testsuite + + +def empty_main_assembly(): + return """ + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %4 "main" + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd""" + + +@inside_spirv_testsuite('SpirvOptBase') +class TestAssemblyFileAsOnlyParameter(expect.ValidObjectFile1_3): + """Tests that spirv-opt accepts a SPIR-V object file.""" + + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + output = placeholder.TempFileName('output.spv') + spirv_args = [shader, '-o', output] + expected_object_filenames = (output) + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestHelpFlag(expect.ReturnCodeIsZero, expect.StdoutMatch): + """Test the --help flag.""" + + spirv_args = ['--help'] + expected_stdout = re.compile(r'.*The SPIR-V binary is read from ') + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestValidPassFlags(expect.ValidObjectFile1_3, + expect.ExecutedListOfPasses): + """Tests that spirv-opt accepts all valid optimization flags.""" + + flags = [ + '--ccp', '--cfg-cleanup', '--combine-access-chains', '--compact-ids', + '--convert-local-access-chains', '--copy-propagate-arrays', + '--eliminate-common-uniform', '--eliminate-dead-branches', + '--eliminate-dead-code-aggressive', '--eliminate-dead-const', + '--eliminate-dead-functions', '--eliminate-dead-inserts', + '--eliminate-dead-variables', '--eliminate-insert-extract', + '--eliminate-local-multi-store', '--eliminate-local-single-block', + '--eliminate-local-single-store', '--flatten-decorations', + '--fold-spec-const-op-composite', '--freeze-spec-const', + '--if-conversion', '--inline-entry-points-exhaustive', '--loop-fission', + '20', '--loop-fusion', '5', '--loop-unroll', '--loop-unroll-partial', '3', + '--loop-peeling', '--merge-blocks', '--merge-return', '--loop-unswitch', + '--private-to-local', '--reduce-load-size', '--redundancy-elimination', + '--remove-duplicates', '--replace-invalid-opcode', '--ssa-rewrite', + '--scalar-replacement', '--scalar-replacement=42', '--strength-reduction', + '--strip-debug', '--strip-reflect', '--vector-dce', '--workaround-1209', + '--unify-const' + ] + expected_passes = [ + 'ccp', + 'cfg-cleanup', + 'combine-access-chains', + 'compact-ids', + 'convert-local-access-chains', + 'copy-propagate-arrays', + 'eliminate-common-uniform', + 'eliminate-dead-branches', + 'eliminate-dead-code-aggressive', + 'eliminate-dead-const', + 'eliminate-dead-functions', + 'eliminate-dead-inserts', + 'eliminate-dead-variables', + # --eliminate-insert-extract runs the simplify-instructions pass. + 'simplify-instructions', + 'eliminate-local-multi-store', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'flatten-decorations', + 'fold-spec-const-op-composite', + 'freeze-spec-const', + 'if-conversion', + 'inline-entry-points-exhaustive', + 'loop-fission', + 'loop-fusion', + 'loop-unroll', + 'loop-unroll', + 'loop-peeling', + 'merge-blocks', + 'merge-return', + 'loop-unswitch', + 'private-to-local', + 'reduce-load-size', + 'redundancy-elimination', + 'remove-duplicates', + 'replace-invalid-opcode', + 'ssa-rewrite', + 'scalar-replacement=100', + 'scalar-replacement=42', + 'strength-reduction', + 'strip-debug', + 'strip-reflect', + 'vector-dce', + 'workaround-1209', + 'unify-const' + ] + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + output = placeholder.TempFileName('output.spv') + spirv_args = [shader, '-o', output, '--print-all'] + flags + expected_object_filenames = (output) + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestPerformanceOptimizationPasses(expect.ValidObjectFile1_3, + expect.ExecutedListOfPasses): + """Tests that spirv-opt schedules all the passes triggered by -O.""" + + flags = ['-O'] + expected_passes = [ + 'merge-return', + 'inline-entry-points-exhaustive', + 'eliminate-dead-code-aggressive', + 'private-to-local', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'eliminate-dead-code-aggressive', + 'scalar-replacement=100', + 'convert-local-access-chains', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'eliminate-dead-code-aggressive', + 'eliminate-local-multi-store', + 'eliminate-dead-code-aggressive', + 'ccp', + 'eliminate-dead-code-aggressive', + 'redundancy-elimination', + 'combine-access-chains', + 'simplify-instructions', + 'vector-dce', + 'eliminate-dead-inserts', + 'eliminate-dead-branches', + 'simplify-instructions', + 'if-conversion', + 'copy-propagate-arrays', + 'reduce-load-size', + 'eliminate-dead-code-aggressive', + 'merge-blocks', + 'redundancy-elimination', + 'eliminate-dead-branches', + 'merge-blocks', + 'simplify-instructions', + ] + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + output = placeholder.TempFileName('output.spv') + spirv_args = [shader, '-o', output, '--print-all'] + flags + expected_object_filenames = (output) + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestSizeOptimizationPasses(expect.ValidObjectFile1_3, + expect.ExecutedListOfPasses): + """Tests that spirv-opt schedules all the passes triggered by -Os.""" + + flags = ['-Os'] + expected_passes = [ + 'merge-return', + 'inline-entry-points-exhaustive', + 'eliminate-dead-code-aggressive', + 'private-to-local', + 'scalar-replacement=100', + 'convert-local-access-chains', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'eliminate-dead-code-aggressive', + 'simplify-instructions', + 'eliminate-dead-inserts', + 'eliminate-local-multi-store', + 'eliminate-dead-code-aggressive', + 'ccp', + 'eliminate-dead-code-aggressive', + 'eliminate-dead-branches', + 'if-conversion', + 'eliminate-dead-code-aggressive', + 'merge-blocks', + 'simplify-instructions', + 'eliminate-dead-inserts', + 'redundancy-elimination', + 'cfg-cleanup', + 'eliminate-dead-code-aggressive', + ] + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + output = placeholder.TempFileName('output.spv') + spirv_args = [shader, '-o', output, '--print-all'] + flags + expected_object_filenames = (output) + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLegalizationPasses(expect.ValidObjectFile1_3, + expect.ExecutedListOfPasses): + """Tests that spirv-opt schedules all the passes triggered by --legalize-hlsl. + """ + + flags = ['--legalize-hlsl'] + expected_passes = [ + 'eliminate-dead-branches', + 'merge-return', + 'inline-entry-points-exhaustive', + 'eliminate-dead-functions', + 'private-to-local', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'eliminate-dead-code-aggressive', + 'scalar-replacement=0', + 'eliminate-local-single-block', + 'eliminate-local-single-store', + 'eliminate-dead-code-aggressive', + 'eliminate-local-multi-store', + 'eliminate-dead-code-aggressive', + 'ccp', + 'eliminate-dead-branches', + 'simplify-instructions', + 'eliminate-dead-code-aggressive', + 'copy-propagate-arrays', + 'vector-dce', + 'eliminate-dead-inserts', + 'reduce-load-size', + 'eliminate-dead-code-aggressive', + ] + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + output = placeholder.TempFileName('output.spv') + spirv_args = [shader, '-o', output, '--print-all'] + flags + expected_object_filenames = (output) + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestScalarReplacementArgsNegative(expect.ErrorMessageSubstr): + """Tests invalid arguments to --scalar-replacement.""" + + spirv_args = ['--scalar-replacement=-10'] + expected_error_substr = 'must have no arguments or a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestScalarReplacementArgsInvalidNumber(expect.ErrorMessageSubstr): + """Tests invalid arguments to --scalar-replacement.""" + + spirv_args = ['--scalar-replacement=a10f'] + expected_error_substr = 'must have no arguments or a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopFissionArgsNegative(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-fission.""" + + spirv_args = ['--loop-fission=-10'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopFissionArgsInvalidNumber(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-fission.""" + + spirv_args = ['--loop-fission=a10f'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopFusionArgsNegative(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-fusion.""" + + spirv_args = ['--loop-fusion=-10'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopFusionArgsInvalidNumber(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-fusion.""" + + spirv_args = ['--loop-fusion=a10f'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopUnrollPartialArgsNegative(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-unroll-partial.""" + + spirv_args = ['--loop-unroll-partial=-10'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopUnrollPartialArgsInvalidNumber(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-unroll-partial.""" + + spirv_args = ['--loop-unroll-partial=a10f'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopPeelingThresholdArgsNegative(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-peeling-threshold.""" + + spirv_args = ['--loop-peeling-threshold=-10'] + expected_error_substr = 'must have a positive integer argument' + + +@inside_spirv_testsuite('SpirvOptFlags') +class TestLoopPeelingThresholdArgsInvalidNumber(expect.ErrorMessageSubstr): + """Tests invalid arguments to --loop-peeling-threshold.""" + + spirv_args = ['--loop-peeling-threshold=a10f'] + expected_error_substr = 'must have a positive integer argument' diff --git a/3rdparty/spirv-tools/test/tools/opt/oconfig.py b/3rdparty/spirv-tools/test/tools/opt/oconfig.py new file mode 100644 index 000000000..337237994 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/opt/oconfig.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import placeholder +import expect +import re + +from spirv_test_framework import inside_spirv_testsuite + + +def empty_main_assembly(): + return """ + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %4 "main" + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd""" + + +@inside_spirv_testsuite('SpirvOptConfigFile') +class TestOconfigEmpty(expect.SuccessfulReturn): + """Tests empty config files are accepted.""" + + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + config = placeholder.ConfigFlagsFile('', '.cfg') + spirv_args = [shader, '-o', placeholder.TempFileName('output.spv'), config] + + +@inside_spirv_testsuite('SpirvOptConfigFile') +class TestOconfigComments(expect.SuccessfulReturn): + """Tests empty config files are accepted. + + https://github.com/KhronosGroup/SPIRV-Tools/issues/1778 + """ + + shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm') + config = placeholder.ConfigFlagsFile(""" +# This is a comment. +-O +--loop-unroll +""", '.cfg') + spirv_args = [shader, '-o', placeholder.TempFileName('output.spv'), config] diff --git a/3rdparty/spirv-tools/test/tools/placeholder.py b/3rdparty/spirv-tools/test/tools/placeholder.py new file mode 100755 index 000000000..7de3c467a --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/placeholder.py @@ -0,0 +1,213 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A number of placeholders and their rules for expansion when used in tests. + +These placeholders, when used in spirv_args or expected_* variables of +SpirvTest, have special meanings. In spirv_args, they will be substituted by +the result of instantiate_for_spirv_args(), while in expected_*, by +instantiate_for_expectation(). A TestCase instance will be passed in as +argument to the instantiate_*() methods. +""" + +import os +import subprocess +import tempfile +from string import Template + + +class PlaceHolderException(Exception): + """Exception class for PlaceHolder.""" + pass + + +class PlaceHolder(object): + """Base class for placeholders.""" + + def instantiate_for_spirv_args(self, testcase): + """Instantiation rules for spirv_args. + + This method will be called when the current placeholder appears in + spirv_args. + + Returns: + A string to replace the current placeholder in spirv_args. + """ + raise PlaceHolderException('Subclass should implement this function.') + + def instantiate_for_expectation(self, testcase): + """Instantiation rules for expected_*. + + This method will be called when the current placeholder appears in + expected_*. + + Returns: + A string to replace the current placeholder in expected_*. + """ + raise PlaceHolderException('Subclass should implement this function.') + + +class FileShader(PlaceHolder): + """Stands for a shader whose source code is in a file.""" + + def __init__(self, source, suffix, assembly_substr=None): + assert isinstance(source, str) + assert isinstance(suffix, str) + self.source = source + self.suffix = suffix + self.filename = None + # If provided, this is a substring which is expected to be in + # the disassembly of the module generated from this input file. + self.assembly_substr = assembly_substr + + def instantiate_for_spirv_args(self, testcase): + """Creates a temporary file and writes the source into it. + + Returns: + The name of the temporary file. + """ + shader, self.filename = tempfile.mkstemp( + dir=testcase.directory, suffix=self.suffix) + shader_object = os.fdopen(shader, 'w') + shader_object.write(self.source) + shader_object.close() + return self.filename + + def instantiate_for_expectation(self, testcase): + assert self.filename is not None + return self.filename + + +class ConfigFlagsFile(PlaceHolder): + """Stands for a configuration file for spirv-opt generated out of a string.""" + + def __init__(self, content, suffix): + assert isinstance(content, str) + assert isinstance(suffix, str) + self.content = content + self.suffix = suffix + self.filename = None + + def instantiate_for_spirv_args(self, testcase): + """Creates a temporary file and writes content into it. + + Returns: + The name of the temporary file. + """ + temp_fd, self.filename = tempfile.mkstemp( + dir=testcase.directory, suffix=self.suffix) + fd = os.fdopen(temp_fd, 'w') + fd.write(self.content) + fd.close() + return '-Oconfig=%s' % self.filename + + def instantiate_for_expectation(self, testcase): + assert self.filename is not None + return self.filename + + +class FileSPIRVShader(PlaceHolder): + """Stands for a source shader file which must be converted to SPIR-V.""" + + def __init__(self, source, suffix, assembly_substr=None): + assert isinstance(source, str) + assert isinstance(suffix, str) + self.source = source + self.suffix = suffix + self.filename = None + # If provided, this is a substring which is expected to be in + # the disassembly of the module generated from this input file. + self.assembly_substr = assembly_substr + + def instantiate_for_spirv_args(self, testcase): + """Creates a temporary file, writes the source into it and assembles it. + + Returns: + The name of the assembled temporary file. + """ + shader, asm_filename = tempfile.mkstemp( + dir=testcase.directory, suffix=self.suffix) + shader_object = os.fdopen(shader, 'w') + shader_object.write(self.source) + shader_object.close() + self.filename = '%s.spv' % asm_filename + cmd = [ + testcase.test_manager.assembler_path, asm_filename, '-o', self.filename + ] + process = subprocess.Popen( + args=cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=testcase.directory) + output = process.communicate() + assert process.returncode == 0 and not output[0] and not output[1] + return self.filename + + def instantiate_for_expectation(self, testcase): + assert self.filename is not None + return self.filename + + +class StdinShader(PlaceHolder): + """Stands for a shader whose source code is from stdin.""" + + def __init__(self, source): + assert isinstance(source, str) + self.source = source + self.filename = None + + def instantiate_for_spirv_args(self, testcase): + """Writes the source code back to the TestCase instance.""" + testcase.stdin_shader = self.source + self.filename = '-' + return self.filename + + def instantiate_for_expectation(self, testcase): + assert self.filename is not None + return self.filename + + +class TempFileName(PlaceHolder): + """Stands for a temporary file's name.""" + + def __init__(self, filename): + assert isinstance(filename, str) + assert filename != '' + self.filename = filename + + def instantiate_for_spirv_args(self, testcase): + return os.path.join(testcase.directory, self.filename) + + def instantiate_for_expectation(self, testcase): + return os.path.join(testcase.directory, self.filename) + + +class SpecializedString(PlaceHolder): + """Returns a string that has been specialized based on TestCase. + + The string is specialized by expanding it as a string.Template + with all of the specialization being done with each $param replaced + by the associated member on TestCase. + """ + + def __init__(self, filename): + assert isinstance(filename, str) + assert filename != '' + self.filename = filename + + def instantiate_for_spirv_args(self, testcase): + return Template(self.filename).substitute(vars(testcase)) + + def instantiate_for_expectation(self, testcase): + return Template(self.filename).substitute(vars(testcase)) diff --git a/3rdparty/spirv-tools/test/tools/spirv_test_framework.py b/3rdparty/spirv-tools/test/tools/spirv_test_framework.py new file mode 100755 index 000000000..03ad08fa8 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/spirv_test_framework.py @@ -0,0 +1,375 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Manages and runs tests from the current working directory. + +This will traverse the current working directory and look for python files that +contain subclasses of SpirvTest. + +If a class has an @inside_spirv_testsuite decorator, an instance of that +class will be created and serve as a test case in that testsuite. The test +case is then run by the following steps: + + 1. A temporary directory will be created. + 2. The spirv_args member variable will be inspected and all placeholders in it + will be expanded by calling instantiate_for_spirv_args() on placeholders. + The transformed list elements are then supplied as arguments to the spirv-* + tool under test. + 3. If the environment member variable exists, its write() method will be + invoked. + 4. All expected_* member variables will be inspected and all placeholders in + them will be expanded by calling instantiate_for_expectation() on those + placeholders. After placeholder expansion, if the expected_* variable is + a list, its element will be joined together with '' to form a single + string. These expected_* variables are to be used by the check_*() methods. + 5. The spirv-* tool will be run with the arguments supplied in spirv_args. + 6. All check_*() member methods will be called by supplying a TestStatus as + argument. Each check_*() method is expected to return a (Success, Message) + pair where Success is a boolean indicating success and Message is an error + message. + 7. If any check_*() method fails, the error message is output and the + current test case fails. + +If --leave-output was not specified, all temporary files and directories will +be deleted. +""" + +from __future__ import print_function + +import argparse +import fnmatch +import inspect +import os +import shutil +import subprocess +import sys +import tempfile +from collections import defaultdict +from placeholder import PlaceHolder + +EXPECTED_BEHAVIOR_PREFIX = 'expected_' +VALIDATE_METHOD_PREFIX = 'check_' + + +def get_all_variables(instance): + """Returns the names of all the variables in instance.""" + return [v for v in dir(instance) if not callable(getattr(instance, v))] + + +def get_all_methods(instance): + """Returns the names of all methods in instance.""" + return [m for m in dir(instance) if callable(getattr(instance, m))] + + +def get_all_superclasses(cls): + """Returns all superclasses of a given class. + + Returns: + A list of superclasses of the given class. The order guarantees that + * A Base class precedes its derived classes, e.g., for "class B(A)", it + will be [..., A, B, ...]. + * When there are multiple base classes, base classes declared first + precede those declared later, e.g., for "class C(A, B), it will be + [..., A, B, C, ...] + """ + classes = [] + for superclass in cls.__bases__: + for c in get_all_superclasses(superclass): + if c not in classes: + classes.append(c) + for superclass in cls.__bases__: + if superclass not in classes: + classes.append(superclass) + return classes + + +def get_all_test_methods(test_class): + """Gets all validation methods. + + Returns: + A list of validation methods. The order guarantees that + * A method defined in superclass precedes one defined in subclass, + e.g., for "class A(B)", methods defined in B precedes those defined + in A. + * If a subclass has more than one superclass, e.g., "class C(A, B)", + then methods defined in A precedes those defined in B. + """ + classes = get_all_superclasses(test_class) + classes.append(test_class) + all_tests = [ + m for c in classes for m in get_all_methods(c) + if m.startswith(VALIDATE_METHOD_PREFIX) + ] + unique_tests = [] + for t in all_tests: + if t not in unique_tests: + unique_tests.append(t) + return unique_tests + + +class SpirvTest: + """Base class for spirv test cases. + + Subclasses define test cases' facts (shader source code, spirv command, + result validation), which will be used by the TestCase class for running + tests. Subclasses should define spirv_args (specifying spirv_tool command + arguments), and at least one check_*() method (for result validation) for + a full-fledged test case. All check_*() methods should take a TestStatus + parameter and return a (Success, Message) pair, in which Success is a + boolean indicating success and Message is an error message. The test passes + iff all check_*() methods returns true. + + Often, a test case class will delegate the check_* behaviors by inheriting + from other classes. + """ + + def name(self): + return self.__class__.__name__ + + +class TestStatus: + """A struct for holding run status of a test case.""" + + def __init__(self, test_manager, returncode, stdout, stderr, directory, + inputs, input_filenames): + self.test_manager = test_manager + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + # temporary directory where the test runs + self.directory = directory + # List of inputs, as PlaceHolder objects. + self.inputs = inputs + # the names of input shader files (potentially including paths) + self.input_filenames = input_filenames + + +class SpirvTestException(Exception): + """SpirvTest exception class.""" + pass + + +def inside_spirv_testsuite(testsuite_name): + """Decorator for subclasses of SpirvTest. + + This decorator checks that a class meets the requirements (see below) + for a test case class, and then puts the class in a certain testsuite. + * The class needs to be a subclass of SpirvTest. + * The class needs to have spirv_args defined as a list. + * The class needs to define at least one check_*() methods. + * All expected_* variables required by check_*() methods can only be + of bool, str, or list type. + * Python runtime will throw an exception if the expected_* member + attributes required by check_*() methods are missing. + """ + + def actual_decorator(cls): + if not inspect.isclass(cls): + raise SpirvTestException('Test case should be a class') + if not issubclass(cls, SpirvTest): + raise SpirvTestException( + 'All test cases should be subclasses of SpirvTest') + if 'spirv_args' not in get_all_variables(cls): + raise SpirvTestException('No spirv_args found in the test case') + if not isinstance(cls.spirv_args, list): + raise SpirvTestException('spirv_args needs to be a list') + if not any( + [m.startswith(VALIDATE_METHOD_PREFIX) for m in get_all_methods(cls)]): + raise SpirvTestException('No check_*() methods found in the test case') + if not all( + [isinstance(v, (bool, str, list)) for v in get_all_variables(cls)]): + raise SpirvTestException( + 'expected_* variables are only allowed to be bool, str, or ' + 'list type.') + cls.parent_testsuite = testsuite_name + return cls + + return actual_decorator + + +class TestManager: + """Manages and runs a set of tests.""" + + def __init__(self, executable_path, assembler_path, disassembler_path): + self.executable_path = executable_path + self.assembler_path = assembler_path + self.disassembler_path = disassembler_path + self.num_successes = 0 + self.num_failures = 0 + self.num_tests = 0 + self.leave_output = False + self.tests = defaultdict(list) + + def notify_result(self, test_case, success, message): + """Call this to notify the manager of the results of a test run.""" + self.num_successes += 1 if success else 0 + self.num_failures += 0 if success else 1 + counter_string = str(self.num_successes + self.num_failures) + '/' + str( + self.num_tests) + print('%-10s %-40s ' % (counter_string, test_case.test.name()) + + ('Passed' if success else '-Failed-')) + if not success: + print(' '.join(test_case.command)) + print(message) + + def add_test(self, testsuite, test): + """Add this to the current list of test cases.""" + self.tests[testsuite].append(TestCase(test, self)) + self.num_tests += 1 + + def run_tests(self): + for suite in self.tests: + print('SPIRV tool test suite: "{suite}"'.format(suite=suite)) + for x in self.tests[suite]: + x.runTest() + + +class TestCase: + """A single test case that runs in its own directory.""" + + def __init__(self, test, test_manager): + self.test = test + self.test_manager = test_manager + self.inputs = [] # inputs, as PlaceHolder objects. + self.file_shaders = [] # filenames of shader files. + self.stdin_shader = None # text to be passed to spirv_tool as stdin + + def setUp(self): + """Creates environment and instantiates placeholders for the test case.""" + + self.directory = tempfile.mkdtemp(dir=os.getcwd()) + spirv_args = self.test.spirv_args + # Instantiate placeholders in spirv_args + self.test.spirv_args = [ + arg.instantiate_for_spirv_args(self) + if isinstance(arg, PlaceHolder) else arg for arg in self.test.spirv_args + ] + # Get all shader files' names + self.inputs = [arg for arg in spirv_args if isinstance(arg, PlaceHolder)] + self.file_shaders = [arg.filename for arg in self.inputs] + + if 'environment' in get_all_variables(self.test): + self.test.environment.write(self.directory) + + expectations = [ + v for v in get_all_variables(self.test) + if v.startswith(EXPECTED_BEHAVIOR_PREFIX) + ] + # Instantiate placeholders in expectations + for expectation_name in expectations: + expectation = getattr(self.test, expectation_name) + if isinstance(expectation, list): + expanded_expections = [ + element.instantiate_for_expectation(self) + if isinstance(element, PlaceHolder) else element + for element in expectation + ] + setattr(self.test, expectation_name, expanded_expections) + elif isinstance(expectation, PlaceHolder): + setattr(self.test, expectation_name, + expectation.instantiate_for_expectation(self)) + + def tearDown(self): + """Removes the directory if we were not instructed to do otherwise.""" + if not self.test_manager.leave_output: + shutil.rmtree(self.directory) + + def runTest(self): + """Sets up and runs a test, reports any failures and then cleans up.""" + self.setUp() + success = False + message = '' + try: + self.command = [self.test_manager.executable_path] + self.command.extend(self.test.spirv_args) + + process = subprocess.Popen( + args=self.command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=self.directory) + output = process.communicate(self.stdin_shader) + test_status = TestStatus(self.test_manager, process.returncode, output[0], + output[1], self.directory, self.inputs, + self.file_shaders) + run_results = [ + getattr(self.test, test_method)(test_status) + for test_method in get_all_test_methods(self.test.__class__) + ] + success, message = zip(*run_results) + success = all(success) + message = '\n'.join(message) + except Exception as e: + success = False + message = str(e) + self.test_manager.notify_result( + self, success, + message + '\nSTDOUT:\n%s\nSTDERR:\n%s' % (output[0], output[1])) + self.tearDown() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + 'spirv_tool', + metavar='path/to/spirv_tool', + type=str, + nargs=1, + help='Path to the spirv-* tool under test') + parser.add_argument( + 'spirv_as', + metavar='path/to/spirv-as', + type=str, + nargs=1, + help='Path to spirv-as') + parser.add_argument( + 'spirv_dis', + metavar='path/to/spirv-dis', + type=str, + nargs=1, + help='Path to spirv-dis') + parser.add_argument( + '--leave-output', + action='store_const', + const=1, + help='Do not clean up temporary directories') + parser.add_argument( + '--test-dir', nargs=1, help='Directory to gather the tests from') + args = parser.parse_args() + default_path = sys.path + root_dir = os.getcwd() + if args.test_dir: + root_dir = args.test_dir[0] + manager = TestManager(args.spirv_tool[0], args.spirv_as[0], args.spirv_dis[0]) + if args.leave_output: + manager.leave_output = True + for root, _, filenames in os.walk(root_dir): + for filename in fnmatch.filter(filenames, '*.py'): + if filename.endswith('nosetest.py'): + # Skip nose tests, which are for testing functions of + # the test framework. + continue + sys.path = default_path + sys.path.append(root) + mod = __import__(os.path.splitext(filename)[0]) + for _, obj, in inspect.getmembers(mod): + if inspect.isclass(obj) and hasattr(obj, 'parent_testsuite'): + manager.add_test(obj.parent_testsuite, obj()) + manager.run_tests() + if manager.num_failures > 0: + sys.exit(-1) + + +if __name__ == '__main__': + main() diff --git a/3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py b/3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py new file mode 100755 index 000000000..c0fbed581 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py @@ -0,0 +1,155 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from spirv_test_framework import get_all_test_methods, get_all_superclasses +from nose.tools import assert_equal, with_setup + + +# Classes to be used in testing get_all_{superclasses|test_methods}() +class Root: + + def check_root(self): + pass + + +class A(Root): + + def check_a(self): + pass + + +class B(Root): + + def check_b(self): + pass + + +class C(Root): + + def check_c(self): + pass + + +class D(Root): + + def check_d(self): + pass + + +class E(Root): + + def check_e(self): + pass + + +class H(B, C, D): + + def check_h(self): + pass + + +class I(E): + + def check_i(self): + pass + + +class O(H, I): + + def check_o(self): + pass + + +class U(A, O): + + def check_u(self): + pass + + +class X(U, A): + + def check_x(self): + pass + + +class R1: + + def check_r1(self): + pass + + +class R2: + + def check_r2(self): + pass + + +class Multi(R1, R2): + + def check_multi(self): + pass + + +def nosetest_get_all_superclasses(): + """Tests get_all_superclasses().""" + + assert_equal(get_all_superclasses(A), [Root]) + assert_equal(get_all_superclasses(B), [Root]) + assert_equal(get_all_superclasses(C), [Root]) + assert_equal(get_all_superclasses(D), [Root]) + assert_equal(get_all_superclasses(E), [Root]) + + assert_equal(get_all_superclasses(H), [Root, B, C, D]) + assert_equal(get_all_superclasses(I), [Root, E]) + + assert_equal(get_all_superclasses(O), [Root, B, C, D, E, H, I]) + + assert_equal(get_all_superclasses(U), [Root, B, C, D, E, H, I, A, O]) + assert_equal(get_all_superclasses(X), [Root, B, C, D, E, H, I, A, O, U]) + + assert_equal(get_all_superclasses(Multi), [R1, R2]) + + +def nosetest_get_all_methods(): + """Tests get_all_test_methods().""" + assert_equal(get_all_test_methods(A), ['check_root', 'check_a']) + assert_equal(get_all_test_methods(B), ['check_root', 'check_b']) + assert_equal(get_all_test_methods(C), ['check_root', 'check_c']) + assert_equal(get_all_test_methods(D), ['check_root', 'check_d']) + assert_equal(get_all_test_methods(E), ['check_root', 'check_e']) + + assert_equal( + get_all_test_methods(H), + ['check_root', 'check_b', 'check_c', 'check_d', 'check_h']) + assert_equal(get_all_test_methods(I), ['check_root', 'check_e', 'check_i']) + + assert_equal( + get_all_test_methods(O), [ + 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', + 'check_i', 'check_o' + ]) + + assert_equal( + get_all_test_methods(U), [ + 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', + 'check_i', 'check_a', 'check_o', 'check_u' + ]) + assert_equal( + get_all_test_methods(X), [ + 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', + 'check_i', 'check_a', 'check_o', 'check_u', 'check_x' + ]) + + assert_equal( + get_all_test_methods(Multi), ['check_r1', 'check_r2', 'check_multi']) diff --git a/3rdparty/spirv-tools/test/unit_spirv.cpp b/3rdparty/spirv-tools/test/unit_spirv.cpp index c2a770a88..84ed87a51 100644 --- a/3rdparty/spirv-tools/test/unit_spirv.cpp +++ b/3rdparty/spirv-tools/test/unit_spirv.cpp @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unit_spirv.h" +#include "test/unit_spirv.h" #include "gmock/gmock.h" -#include "test_fixture.h" +#include "test/test_fixture.h" +namespace spvtools { namespace { using spvtest::MakeVector; @@ -50,4 +51,5 @@ TEST_P(RoundTripTest, Sample) { << GetParam(); } -} // anonymous namespace +} // namespace +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/unit_spirv.h b/3rdparty/spirv-tools/test/unit_spirv.h index 45e8c2379..224428884 100644 --- a/3rdparty/spirv-tools/test/unit_spirv.h +++ b/3rdparty/spirv-tools/test/unit_spirv.h @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TEST_UNITSPIRV_H_ -#define LIBSPIRV_TEST_UNITSPIRV_H_ +#ifndef TEST_UNIT_SPIRV_H_ +#define TEST_UNIT_SPIRV_H_ #include #include +#include #include +#include "gtest/gtest.h" #include "source/assembly_grammar.h" #include "source/binary.h" #include "source/diagnostic.h" @@ -28,11 +30,9 @@ #include "source/spirv_endian.h" #include "source/text.h" #include "source/text_handler.h" -#include "source/validate.h" +#include "source/val/validate.h" #include "spirv-tools/libspirv.h" -#include - #ifdef __ANDROID__ #include namespace std { @@ -218,17 +218,17 @@ inline std::vector AllTargetEnvironments() { SPV_ENV_OPENGL_4_1, SPV_ENV_OPENGL_4_2, SPV_ENV_OPENGL_4_3, SPV_ENV_OPENGL_4_5, SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, - SPV_ENV_VULKAN_1_1, + SPV_ENV_VULKAN_1_1, SPV_ENV_WEBGPU_0, }; } // Returns the capabilities in a CapabilitySet as an ordered vector. inline std::vector ElementsIn( - const libspirv::CapabilitySet& capabilities) { + const spvtools::CapabilitySet& capabilities) { std::vector result; capabilities.ForEach([&result](SpvCapability c) { result.push_back(c); }); return result; } } // namespace spvtest -#endif // LIBSPIRV_TEST_UNITSPIRV_H_ +#endif // TEST_UNIT_SPIRV_H_ diff --git a/3rdparty/spirv-tools/test/util/CMakeLists.txt b/3rdparty/spirv-tools/test/util/CMakeLists.txt index 9b6ca2c1f..66d4e8a42 100644 --- a/3rdparty/spirv-tools/test/util/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/util/CMakeLists.txt @@ -14,5 +14,13 @@ add_spvtools_unittest(TARGET util_intrusive_list SRCS ilist_test.cpp +) + +add_spvtools_unittest(TARGET bit_vector + SRCS bit_vector_test.cpp LIBS SPIRV-Tools-opt ) + +add_spvtools_unittest(TARGET small_vector + SRCS small_vector_test.cpp +) diff --git a/3rdparty/spirv-tools/test/util/bit_vector_test.cpp b/3rdparty/spirv-tools/test/util/bit_vector_test.cpp new file mode 100644 index 000000000..8d967f8f9 --- /dev/null +++ b/3rdparty/spirv-tools/test/util/bit_vector_test.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" + +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace utils { +namespace { + +using BitVectorTest = ::testing::Test; + +TEST(BitVectorTest, Initialize) { + BitVector bvec; + + // Checks that all values are 0. Also tests checking a bit past the end of + // the vector containing the bits. + for (int i = 1; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Get(i)); + } +} + +TEST(BitVectorTest, Set) { + BitVector bvec; + + // Since 10,000 is larger than the initial size, this tests the resizing + // code. + for (int i = 3; i < 10000; i *= 2) { + bvec.Set(i); + } + + // Check that bits that were not set are 0. + for (int i = 1; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Get(i)); + } + + // Check that bits that were set are 1. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Get(i)); + } +} + +TEST(BitVectorTest, SetReturnValue) { + BitVector bvec; + + // Make sure |Set| returns false when the bit was not set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Set(i)); + } + + // Make sure |Set| returns true when the bit was already set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Set(i)); + } +} + +TEST(BitVectorTest, Clear) { + BitVector bvec; + for (int i = 3; i < 10000; i *= 2) { + bvec.Set(i); + } + + // Check that the bits were properly set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Get(i)); + } + + // Clear all of the bits except for bit 3. + for (int i = 6; i < 10000; i *= 2) { + bvec.Clear(i); + } + + // Make sure bit 3 was not cleared. + EXPECT_TRUE(bvec.Get(3)); + + // Make sure all of the other bits that were set have been cleared. + for (int i = 6; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Get(i)); + } +} + +TEST(BitVectorTest, ClearReturnValue) { + BitVector bvec; + for (int i = 3; i < 10000; i *= 2) { + bvec.Set(i); + } + + // Make sure |Clear| returns true if the bit was set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Clear(i)); + } + + // Make sure |Clear| returns false if the bit was not set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Clear(i)); + } +} + +TEST(BitVectorTest, SimpleOrTest) { + BitVector bvec1; + bvec1.Set(3); + bvec1.Set(4); + + BitVector bvec2; + bvec2.Set(2); + bvec2.Set(4); + + // Check that |bvec1| changed when doing the |Or| operation. + EXPECT_TRUE(bvec1.Or(bvec2)); + + // Check that the values are all correct. + EXPECT_FALSE(bvec1.Get(0)); + EXPECT_FALSE(bvec1.Get(1)); + EXPECT_TRUE(bvec1.Get(2)); + EXPECT_TRUE(bvec1.Get(3)); + EXPECT_TRUE(bvec1.Get(4)); +} + +TEST(BitVectorTest, ResizingOrTest) { + BitVector bvec1; + bvec1.Set(3); + bvec1.Set(4); + + BitVector bvec2; + bvec2.Set(10000); + + // Similar to above except with a large value to test resizing. + EXPECT_TRUE(bvec1.Or(bvec2)); + EXPECT_FALSE(bvec1.Get(0)); + EXPECT_FALSE(bvec1.Get(1)); + EXPECT_FALSE(bvec1.Get(2)); + EXPECT_TRUE(bvec1.Get(3)); + EXPECT_TRUE(bvec1.Get(10000)); +} + +TEST(BitVectorTest, SubsetOrTest) { + BitVector bvec1; + bvec1.Set(3); + bvec1.Set(4); + + BitVector bvec2; + bvec2.Set(3); + + // |Or| returns false if |bvec1| does not change. + EXPECT_FALSE(bvec1.Or(bvec2)); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/util/ilist_test.cpp b/3rdparty/spirv-tools/test/util/ilist_test.cpp index 5f36f391b..4a546f993 100644 --- a/3rdparty/spirv-tools/test/util/ilist_test.cpp +++ b/3rdparty/spirv-tools/test/util/ilist_test.cpp @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "gmock/gmock.h" +#include "source/util/ilist.h" -#include "util/ilist.h" - +namespace spvtools { +namespace utils { namespace { -using spvtools::utils::IntrusiveList; -using spvtools::utils::IntrusiveNodeBase; using ::testing::ElementsAre; using IListTest = ::testing::Test; @@ -319,4 +319,7 @@ TEST(IListTest, MoveBefore4) { EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5)); } + } // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/util/small_vector_test.cpp b/3rdparty/spirv-tools/test/util/small_vector_test.cpp new file mode 100644 index 000000000..01d7df185 --- /dev/null +++ b/3rdparty/spirv-tools/test/util/small_vector_test.cpp @@ -0,0 +1,598 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "source/util/small_vector.h" + +namespace spvtools { +namespace utils { +namespace { + +using SmallVectorTest = ::testing::Test; + +TEST(SmallVectorTest, Initialize_default) { + SmallVector vec; + + EXPECT_TRUE(vec.empty()); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec.begin(), vec.end()); +} + +TEST(SmallVectorTest, Initialize_list1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_FALSE(vec.empty()); + EXPECT_EQ(vec.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec.size(); ++i) { + EXPECT_EQ(vec[i], result[i]); + } +} + +TEST(SmallVectorTest, Initialize_list2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_FALSE(vec.empty()); + EXPECT_EQ(vec.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec.size(); ++i) { + EXPECT_EQ(vec[i], result[i]); + } +} + +TEST(SmallVectorTest, Initialize_copy1) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_copy2) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_copy_vec1) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_copy_vec2) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_move1) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_move2) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_move_vec1) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_move_vec2) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_iterators1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (uint32_t p : vec) { + EXPECT_EQ(p, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (uint32_t p : vec) { + EXPECT_EQ(p, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators3) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators4) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators_write1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + *it *= 2; + } + + uint32_t result[] = {0, 2, 4, 6}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators_write2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + *it *= 2; + } + + uint32_t result[] = {0, 2, 4, 6}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_front) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.front(), 0); + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + *it += 2; + } + EXPECT_EQ(vec.front(), 2); +} + +TEST(SmallVectorTest, Erase_element_front1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.front(), 0); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin()); + EXPECT_EQ(vec.front(), 1); + EXPECT_EQ(vec.size(), 3); +} + +TEST(SmallVectorTest, Erase_element_front2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.front(), 0); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin()); + EXPECT_EQ(vec.front(), 1); + EXPECT_EQ(vec.size(), 3); +} + +TEST(SmallVectorTest, Erase_element_back1) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec[3], 3); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 3); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_element_back2) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec[3], 3); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 3); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_element_middle1) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 2); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_element_middle2) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 2); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_1) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.end()); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_2) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.end()); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_3) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {2, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.begin() + 2); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_4) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {2, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.begin() + 2); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_5) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 1, vec.begin() + 3); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_6) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 1, vec.begin() + 3); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Push_back) { + SmallVector vec; + SmallVector result = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 0); + vec.push_back(0); + EXPECT_EQ(vec.size(), 1); + vec.push_back(1); + EXPECT_EQ(vec.size(), 2); + vec.push_back(2); + EXPECT_EQ(vec.size(), 3); + vec.push_back(3); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Emplace_back) { + SmallVector vec; + SmallVector result = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 0); + vec.emplace_back(0); + EXPECT_EQ(vec.size(), 1); + vec.emplace_back(1); + EXPECT_EQ(vec.size(), 2); + vec.emplace_back(2); + EXPECT_EQ(vec.size(), 3); + vec.emplace_back(3); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Clear) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {}; + + EXPECT_EQ(vec.size(), 4); + vec.clear(); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Insert1) { + SmallVector vec = {}; + SmallVector insert_values = {10, 11}; + SmallVector result = {10, 11}; + + EXPECT_EQ(vec.size(), 0); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert2) { + SmallVector vec = {}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {10, 11, 12}; + + EXPECT_EQ(vec.size(), 0); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert3) { + SmallVector vec = {0}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {10, 11, 12, 0}; + + EXPECT_EQ(vec.size(), 1); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert4) { + SmallVector vec = {0}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {10, 11, 12, 0}; + + EXPECT_EQ(vec.size(), 1); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert5) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 1, 2, 10, 11, 12}; + + EXPECT_EQ(vec.size(), 3); + auto ret = vec.insert(vec.end(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert6) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 1, 2, 10, 11, 12}; + + EXPECT_EQ(vec.size(), 3); + auto ret = vec.insert(vec.end(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert7) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 10, 11, 12, 1, 2}; + + EXPECT_EQ(vec.size(), 3); + auto ret = + vec.insert(vec.begin() + 1, insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert8) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 10, 11, 12, 1, 2}; + + EXPECT_EQ(vec.size(), 3); + auto ret = + vec.insert(vec.begin() + 1, insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Resize1) { + SmallVector vec = {0, 1, 2}; + SmallVector result = {0, 1, 2, 10, 10, 10}; + + EXPECT_EQ(vec.size(), 3); + vec.resize(6, 10); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize2) { + SmallVector vec = {0, 1, 2}; + SmallVector result = {0, 1, 2, 10, 10, 10}; + + EXPECT_EQ(vec.size(), 3); + vec.resize(6, 10); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize3) { + SmallVector vec = {0, 1, 2}; + SmallVector result = {0, 1, 2, 10, 10, 10}; + + EXPECT_EQ(vec.size(), 3); + vec.resize(6, 10); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize4) { + SmallVector vec = {0, 1, 2, 10, 10, 10}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec.size(), 6); + vec.resize(3, 10); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize5) { + SmallVector vec = {0, 1, 2, 10, 10, 10}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec.size(), 6); + vec.resize(3, 10); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize6) { + SmallVector vec = {0, 1, 2, 10, 10, 10}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec.size(), 6); + vec.resize(3, 10); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/CMakeLists.txt b/3rdparty/spirv-tools/test/val/CMakeLists.txt index 093a04a4d..b1e87da66 100644 --- a/3rdparty/spirv-tools/test/val/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/val/CMakeLists.txt @@ -18,171 +18,57 @@ set(VAL_TEST_COMMON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/val_fixtures.h ) - -add_spvtools_unittest(TARGET val_capability - SRCS val_capability_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_cfg - SRCS val_cfg_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_id - SRCS val_id_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_layout - SRCS val_layout_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_ssa - SRCS val_ssa_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_storage - SRCS val_storage_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_state - SRCS val_state_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_data - SRCS val_data_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_type_unique - SRCS val_type_unique_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_arithmetics - SRCS val_arithmetics_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_composites - SRCS val_composites_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_conversion - SRCS val_conversion_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_derivatives - SRCS val_derivatives_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_logicals - SRCS val_logicals_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_bitwise - SRCS val_bitwise_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_builtins - SRCS val_builtins_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_image - SRCS val_image_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_atomics - SRCS val_atomics_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_barriers - SRCS val_barriers_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_primitives - SRCS val_primitives_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_ext_inst - SRCS val_ext_inst_test.cpp +add_spvtools_unittest(TARGET val_abcde + SRCS + val_adjacency_test.cpp + val_arithmetics_test.cpp + val_atomics_test.cpp + val_barriers_test.cpp + val_bitwise_test.cpp + val_builtins_test.cpp + val_capability_test.cpp + val_cfg_test.cpp + val_composites_test.cpp + val_conversion_test.cpp + val_data_test.cpp + val_decoration_test.cpp + val_derivatives_test.cpp + val_explicit_reserved_test.cpp + val_extensions_test.cpp + val_ext_inst_test.cpp ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} ) add_spvtools_unittest(TARGET val_limits - SRCS val_limits_test.cpp + SRCS val_limits_test.cpp ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} ) -add_spvtools_unittest(TARGET val_validation_state - SRCS val_validation_state_test.cpp +add_spvtools_unittest(TARGET val_ijklmnop + SRCS + val_id_test.cpp + val_image_test.cpp + val_interfaces_test.cpp + val_layout_test.cpp + val_literals_test.cpp + val_logicals_test.cpp + val_non_uniform_test.cpp + val_primitives_test.cpp ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} ) -add_spvtools_unittest(TARGET val_decoration - SRCS val_decoration_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_literals - SRCS val_literals_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_extensions - SRCS val_extensions_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_adjacency - SRCS val_adjacency_test.cpp - ${VAL_TEST_COMMON_SRCS} - LIBS ${SPIRV_TOOLS} -) - -add_spvtools_unittest(TARGET val_version - SRCS val_version_test.cpp +add_spvtools_unittest(TARGET val_stuvw + SRCS + val_ssa_test.cpp + val_state_test.cpp + val_storage_test.cpp + val_type_unique_test.cpp + val_validation_state_test.cpp + val_version_test.cpp + val_webgpu_test.cpp ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} ) diff --git a/3rdparty/spirv-tools/test/val/val_adjacency_test.cpp b/3rdparty/spirv-tools/test/val/val_adjacency_test.cpp index 586cf7e8f..d62830514 100644 --- a/3rdparty/spirv-tools/test/val/val_adjacency_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_adjacency_test.cpp @@ -16,9 +16,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -282,4 +284,6 @@ OpBranch %merge "OpBranchConditional or OpSwitch instruction")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_arithmetics_test.cpp b/3rdparty/spirv-tools/test/val/val_arithmetics_test.cpp index 7c3172f92..1c8d88be5 100644 --- a/3rdparty/spirv-tools/test/val/val_arithmetics_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_arithmetics_test.cpp @@ -17,16 +17,16 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; using ::testing::Not; -using std::string; - using ValidateArithmetics = spvtest::ValidateBase; std::string GenerateCode(const std::string& main_body) { @@ -1273,4 +1273,6 @@ TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) { "SMulExtended")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_atomics_test.cpp b/3rdparty/spirv-tools/test/val/val_atomics_test.cpp index 8d98fd371..9aece39aa 100644 --- a/3rdparty/spirv-tools/test/val/val_atomics_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_atomics_test.cpp @@ -16,9 +16,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -1009,4 +1011,6 @@ OpAtomicStore %u32_var %device %relaxed %u32_1 ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_barriers_test.cpp b/3rdparty/spirv-tools/test/val/val_barriers_test.cpp index e1d510592..38c168eda 100644 --- a/3rdparty/spirv-tools/test/val/val_barriers_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_barriers_test.cpp @@ -16,9 +16,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -319,9 +321,31 @@ OpControlBarrier %subgroup %subgroup %none CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: in Vulkan 1.0 environment Memory Scope is " + "limited to Device, Workgroup and Invocation")); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkan1p1MemoryScopeSubgroup) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkan1p1MemoryScopeCrossDevice) { + const std::string body = R"( +OpControlBarrier %subgroup %cross_device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1)); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ControlBarrier: in Vulkan environment Memory Scope is " - "limited to Device, Workgroup and Invocation")); + HasSubstr("ControlBarrier: in Vulkan environment, Memory Scope " + "cannot be CrossDevice")); } TEST_F(ValidateBarriers, OpControlBarrierAcquireAndRelease) { @@ -353,6 +377,167 @@ OpControlBarrier %workgroup %device %acquire_release_subgroup "Vulkan-supported storage class if Memory Semantics is not None")); } +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionFragment1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWorkgroupExecutionFragment1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionFragment1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionVertex1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWorkgroupExecutionVertex1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionVertex1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionGeometry1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability Geometry\n", "Geometry"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWorkgroupExecutionGeometry1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability Geometry\n", "Geometry"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionGeometry1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability Geometry\n", "Geometry"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, + OpControlBarrierSubgroupExecutionTessellationEvaluation1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationEvaluation"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, + OpControlBarrierWorkgroupExecutionTessellationEvaluation1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationEvaluation"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, + OpControlBarrierSubgroupExecutionTessellationEvaluation1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationEvaluation"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + TEST_F(ValidateBarriers, OpMemoryBarrierSuccess) { const std::string body = R"( OpMemoryBarrier %cross_device %acquire_release_uniform_workgroup @@ -437,9 +622,19 @@ OpMemoryBarrier %subgroup %acquire_release_uniform_workgroup CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("MemoryBarrier: in Vulkan environment Memory Scope is " - "limited to Device, Workgroup and Invocation")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: in Vulkan 1.0 environment Memory Scope is " + "limited to Device, Workgroup and Invocation")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierVulkan1p1MemoryScopeSubgroup) { + const std::string body = R"( +OpMemoryBarrier %subgroup %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); } TEST_F(ValidateBarriers, OpMemoryBarrierAcquireAndRelease) { @@ -609,4 +804,18 @@ OpMemoryNamedBarrier %barrier %workgroup %acquire_and_release_uniform "AcquireRelease or SequentiallyConsistent")); } -} // anonymous namespace +TEST_F(ValidateBarriers, TypeAsMemoryScope) { + const std::string body = R"( +OpMemoryBarrier %u32 %u32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Scope to be a 32-bit int")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_bitwise_test.cpp b/3rdparty/spirv-tools/test/val/val_bitwise_test.cpp index 65ce0acb1..18ccd4f8e 100644 --- a/3rdparty/spirv-tools/test/val/val_bitwise_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_bitwise_test.cpp @@ -17,9 +17,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -541,4 +543,6 @@ TEST_F(ValidateBitwise, OpBitCountBaseWrongDimension) { "BitCount")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_builtins_test.cpp b/3rdparty/spirv-tools/test/val/val_builtins_test.cpp index 78e429ddf..0c8909cec 100644 --- a/3rdparty/spirv-tools/test/val/val_builtins_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_builtins_test.cpp @@ -20,11 +20,15 @@ #include #include #include +#include +#include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { struct TestResult { @@ -49,12 +53,15 @@ using ValidateBuiltIns = spvtest::ValidateBase; using ValidateVulkanCombineBuiltInExecutionModelDataTypeResult = spvtest::ValidateBase>; +using ValidateVulkanCombineBuiltInArrayedVariable = spvtest::ValidateBase< + std::tuple>; struct EntryPoint { std::string name; std::string execution_model; std::string execution_modes; std::string body; + std::string interfaces; }; class CodeGenerator { @@ -80,7 +87,8 @@ std::string CodeGenerator::Build() const { for (const EntryPoint& entry_point : entry_points_) { ss << "OpEntryPoint " << entry_point.execution_model << " %" - << entry_point.name << " \"" << entry_point.name << "\"\n"; + << entry_point.name << " \"" << entry_point.name << "\" " + << entry_point.interfaces << "\n"; } for (const EntryPoint& entry_point : entry_points_) { @@ -186,6 +194,10 @@ std::string GetDefaultShaderTypes() { %f64arr2 = OpTypeArray %f64 %u32_2 %f64arr3 = OpTypeArray %f64 %u32_3 %f64arr4 = OpTypeArray %f64 %u32_4 + +%f32vec3arr3 = OpTypeArray %f32vec3 %u32_3 +%f32vec4arr3 = OpTypeArray %f32vec4 %u32_3 +%f64vec4arr3 = OpTypeArray %f64vec4 %u32_3 )"; } @@ -222,6 +234,10 @@ TEST_P(ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, InMain) { EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = execution_model; + if (strncmp(storage_class, "Input", 5) == 0 || + strncmp(storage_class, "Output", 6) == 0) { + entry_point.interfaces = "%built_in_var"; + } std::ostringstream execution_modes; if (0 == std::strcmp(execution_model, "Fragment")) { @@ -275,6 +291,10 @@ TEST_P(ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, InFunction) { EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = execution_model; + if (strncmp(storage_class, "Input", 5) == 0 || + strncmp(storage_class, "Output", 6) == 0) { + entry_point.interfaces = "%built_in_var"; + } std::ostringstream execution_modes; if (0 == std::strcmp(execution_model, "Fragment")) { @@ -333,6 +353,10 @@ TEST_P(ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, Variable) { EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = execution_model; + if (strncmp(storage_class, "Input", 5) == 0 || + strncmp(storage_class, "Output", 6) == 0) { + entry_point.interfaces = "%built_in_var"; + } // Any kind of reference would do. entry_point.body = R"( %val = OpBitcast %u64 %built_in_var @@ -738,14 +762,23 @@ INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P( LayerAndViewportIndexInvalidExecutionModel, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, - Combine( - Values("Layer", "ViewportIndex"), - Values("Vertex", "GLCompute", "TessellationControl", - "TessellationEvaluation"), - Values("Input"), Values("%u32"), - Values(TestResult( - SPV_ERROR_INVALID_DATA, - "to be used only with Fragment or Geometry execution models"))), ); + Combine(Values("Layer", "ViewportIndex"), + Values("TessellationControl", "GLCompute"), Values("Input"), + Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with Vertex, TessellationEvaluation, " + "Geometry, or Fragment execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexExecutionModelEnabledByCapability, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), + Values("Vertex", "TessellationEvaluation"), Values("Output"), + Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "requires the ShaderViewportIndexLayerEXT capability"))), ); INSTANTIATE_TEST_CASE_P( LayerAndViewportIndexFragmentNotInput, @@ -761,11 +794,13 @@ INSTANTIATE_TEST_CASE_P( LayerAndViewportIndexGeometryNotOutput, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, Combine( - Values("Layer", "ViewportIndex"), Values("Geometry"), Values("Input"), + Values("Layer", "ViewportIndex"), + Values("Vertex", "TessellationEvaluation", "Geometry"), Values("Input"), Values("%u32"), Values(TestResult(SPV_ERROR_INVALID_DATA, - "Input storage class if execution model is Geometry", - "which is called with execution model Geometry"))), ); + "Input storage class if execution model is Vertex, " + "TessellationEvaluation, or Geometry", + "which is called with execution model"))), ); INSTANTIATE_TEST_CASE_P( LayerAndViewportIndexNotIntScalar, @@ -1474,6 +1509,129 @@ INSTANTIATE_TEST_CASE_P( "needs to be a 32-bit int scalar", "has bit width 64"))), ); +TEST_P(ValidateVulkanCombineBuiltInArrayedVariable, Variable) { + const char* const built_in = std::get<0>(GetParam()); + const char* const execution_model = std::get<1>(GetParam()); + const char* const storage_class = std::get<2>(GetParam()); + const char* const data_type = std::get<3>(GetParam()); + const TestResult& test_result = std::get<4>(GetParam()); + + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = "OpDecorate %built_in_var BuiltIn "; + generator.before_types_ += built_in; + generator.before_types_ += "\n"; + + std::ostringstream after_types; + after_types << "%built_in_array = OpTypeArray " << data_type << " %u32_3\n"; + after_types << "%built_in_ptr = OpTypePointer " << storage_class + << " %built_in_array\n"; + after_types << "%built_in_var = OpVariable %built_in_ptr " << storage_class + << "\n"; + generator.after_types_ = after_types.str(); + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = execution_model; + entry_point.interfaces = "%built_in_var"; + // Any kind of reference would do. + entry_point.body = R"( +%val = OpBitcast %u64 %built_in_var +)"; + + std::ostringstream execution_modes; + if (0 == std::strcmp(execution_model, "Fragment")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " OriginUpperLeft\n"; + } + if (0 == std::strcmp(built_in, "FragDepth")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " DepthReplacing\n"; + } + entry_point.execution_modes = execution_modes.str(); + + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(test_result.validation_result, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + if (test_result.error_str) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str)); + } + if (test_result.error_str2) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str2)); + } +} + +INSTANTIATE_TEST_CASE_P(PointSizeArrayedF32TessControl, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("PointSize"), + Values("TessellationControl"), Values("Input"), + Values("%f32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeArrayedF64TessControl, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("PointSize"), Values("TessellationControl"), Values("Input"), + Values("%f64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeArrayedF32Vertex, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("PointSize"), Values("Vertex"), Values("Output"), + Values("%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "is not a float scalar"))), ); + +INSTANTIATE_TEST_CASE_P(PositionArrayedF32Vec4TessControl, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("Position"), + Values("TessellationControl"), Values("Input"), + Values("%f32vec4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PositionArrayedF32Vec3TessControl, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("Position"), Values("TessellationControl"), Values("Input"), + Values("%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + PositionArrayedF32Vec4Vertex, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("Position"), Values("Vertex"), Values("Output"), + Values("%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceOutputSuccess, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("ClipDistance", "CullDistance"), + Values("Geometry", "TessellationControl", "TessellationEvaluation"), + Values("Output"), Values("%f32arr2", "%f32arr4"), + Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceVertexInput, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("ClipDistance", "CullDistance"), Values("Fragment"), + Values("Input"), Values("%f32arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float array", + "components are not float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceNotArray, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("ClipDistance", "CullDistance"), + Values("Geometry", "TessellationControl", "TessellationEvaluation"), + Values("Input"), Values("%f32vec2", "%f32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float array", + "components are not float scalar"))), ); + TEST_F(ValidateBuiltIns, WorkgroupSizeSuccess) { CodeGenerator generator = GetDefaultShaderCodeGenerator(); generator.before_types_ = R"( @@ -1656,6 +1814,29 @@ OpDecorate %workgroup_size BuiltIn WorkgroupSize "(OpConstantComposite) has components with bit width 64.")); } +TEST_F(ValidateBuiltIns, WorkgroupSizePrivateVar) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec3 %u32_1 %u32_1 %u32_1 +%private_ptr_u32vec3 = OpTypePointer Private %u32vec3 +%var = OpVariable %private_ptr_u32vec3 Private %workgroup_size +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + TEST_F(ValidateBuiltIns, GeometryPositionInOutSuccess) { CodeGenerator generator = GetDefaultShaderCodeGenerator(); @@ -1666,11 +1847,13 @@ OpMemberDecorate %output_type 0 BuiltIn Position generator.after_types_ = R"( %input_type = OpTypeStruct %f32vec4 -%input_ptr = OpTypePointer Input %input_type +%arrayed_input_type = OpTypeArray %input_type %u32_3 +%input_ptr = OpTypePointer Input %arrayed_input_type %input = OpVariable %input_ptr Input %input_f32vec4_ptr = OpTypePointer Input %f32vec4 %output_type = OpTypeStruct %f32vec4 -%output_ptr = OpTypePointer Output %output_type +%arrayed_output_type = OpTypeArray %output_type %u32_3 +%output_ptr = OpTypePointer Output %arrayed_output_type %output = OpVariable %output_ptr Output %output_f32vec4_ptr = OpTypePointer Output %f32vec4 )"; @@ -1678,9 +1861,10 @@ OpMemberDecorate %output_type 0 BuiltIn Position EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = "Geometry"; + entry_point.interfaces = "%input %output"; entry_point.body = R"( -%input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 -%output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 +%input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 %u32_0 +%output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 %u32_0 %pos = OpLoad %f32vec4 %input_pos OpStore %output_pos %pos )"; @@ -1712,6 +1896,7 @@ OpMemberDecorate %output_type 0 BuiltIn Position EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = "Geometry"; + entry_point.interfaces = "%input %output"; entry_point.body = R"( %input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 %output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 @@ -1749,6 +1934,7 @@ OpMemberDecorate %output_type 0 BuiltIn FragCoord EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = "Geometry"; + entry_point.interfaces = "%input %output"; entry_point.body = R"( %input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 %output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 @@ -1778,6 +1964,7 @@ OpDecorate %position BuiltIn Position EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = "Vertex"; + entry_point.interfaces = "%position"; entry_point.body = R"( OpStore %position %f32vec4_0123 )"; @@ -1803,6 +1990,7 @@ OpMemberDecorate %output_type 0 BuiltIn Position EntryPoint entry_point; entry_point.name = "vmain"; entry_point.execution_model = "Vertex"; + entry_point.interfaces = "%output"; entry_point.body = R"( %val1 = OpFunctionCall %void %foo )"; @@ -1810,6 +1998,7 @@ OpMemberDecorate %output_type 0 BuiltIn Position entry_point.name = "fmain"; entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; entry_point.execution_modes = "OpExecutionMode %fmain OriginUpperLeft"; entry_point.body = R"( %val2 = OpFunctionCall %void %foo @@ -1851,6 +2040,7 @@ OpMemberDecorate %output_type 0 BuiltIn FragDepth EntryPoint entry_point; entry_point.name = "main"; entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; entry_point.execution_modes = "OpExecutionMode %main OriginUpperLeft"; entry_point.body = R"( %val2 = OpFunctionCall %void %foo @@ -1889,6 +2079,7 @@ OpMemberDecorate %output_type 0 BuiltIn FragDepth EntryPoint entry_point; entry_point.name = "main_d_r"; entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; entry_point.execution_modes = "OpExecutionMode %main_d_r OriginUpperLeft\n" "OpExecutionMode %main_d_r DepthReplacing"; @@ -1899,6 +2090,7 @@ OpMemberDecorate %output_type 0 BuiltIn FragDepth entry_point.name = "main_no_d_r"; entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; entry_point.execution_modes = "OpExecutionMode %main_no_d_r OriginUpperLeft"; entry_point.body = R"( %val3 = OpFunctionCall %void %foo @@ -1921,4 +2113,6 @@ OpFunctionEnd "be declared when using BuiltIn FragDepth")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_capability_test.cpp b/3rdparty/spirv-tools/test/val/val_capability_test.cpp index 1e5b1f03e..2ee7133cd 100644 --- a/3rdparty/spirv-tools/test/val/val_capability_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_capability_test.cpp @@ -18,34 +18,30 @@ #include #include #include +#include -#include - +#include "gmock/gmock.h" #include "source/assembly_grammar.h" -#include "test_fixture.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "source/spirv_target_env.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using spvtest::ScopedContext; -using std::get; -using std::make_pair; -using std::ostringstream; -using std::pair; -using std::string; -using std::tuple; -using std::vector; using testing::Combine; using testing::HasSubstr; using testing::Values; using testing::ValuesIn; -// Parameter for validation test fixtures. The first string is a capability -// name that will begin the assembly under test, the second the remainder -// assembly, and the vector at the end determines whether the test expects -// success or failure. See below for details and convenience methods to access -// each one. +// Parameter for validation test fixtures. The first std::string is a +// capability name that will begin the assembly under test, the second the +// remainder assembly, and the std::vector at the end determines whether the +// test expects success or failure. See below for details and convenience +// methods to access each one. // // The assembly to test is composed from a variable top line and a fixed // remainder. The top line will be an OpCapability instruction, while the @@ -63,27 +59,32 @@ using testing::ValuesIn; // // So how does the test parameter capture which capabilities should cause // success and which shouldn't? The answer is in the last element: it's a -// vector of capabilities that make the remainder assembly succeed. So if the -// first-line capability exists in that vector, success is expected; otherwise, -// failure is expected in the tests. +// std::vector of capabilities that make the remainder assembly succeed. So if +// the first-line capability exists in that std::vector, success is expected; +// otherwise, failure is expected in the tests. // // We will use testing::Combine() to vary the first line: when we combine // AllCapabilities() with a single remainder assembly, we generate enough test // cases to try the assembly with every possible capability that could be // declared. However, Combine() only produces tuples -- it cannot produce, say, // a struct. Therefore, this type must be a tuple. -using CapTestParameter = tuple>>; +using CapTestParameter = + std::tuple>>; -const string& Capability(const CapTestParameter& p) { return get<0>(p); } -const string& Remainder(const CapTestParameter& p) { return get<1>(p).first; } -const vector& MustSucceed(const CapTestParameter& p) { - return get<1>(p).second; +const std::string& Capability(const CapTestParameter& p) { + return std::get<0>(p); +} +const std::string& Remainder(const CapTestParameter& p) { + return std::get<1>(p).first; +} +const std::vector& MustSucceed(const CapTestParameter& p) { + return std::get<1>(p).second; } // Creates assembly to test from p. -string MakeAssembly(const CapTestParameter& p) { - ostringstream ss; - const string& capability = Capability(p); +std::string MakeAssembly(const CapTestParameter& p) { + std::ostringstream ss; + const std::string& capability = Capability(p); if (!capability.empty()) { ss << "OpCapability " << capability << "\n"; } @@ -109,6 +110,8 @@ using ValidateCapabilityV11 = spvtest::ValidateBase; using ValidateCapabilityVulkan10 = spvtest::ValidateBase; // Always assembles using OpenGL 4.0. using ValidateCapabilityOpenGL40 = spvtest::ValidateBase; +// Always assembles using Vulkan 1.1. +using ValidateCapabilityVulkan11 = spvtest::ValidateBase; TEST_F(ValidateCapability, Default) { const char str[] = R"( @@ -126,8 +129,8 @@ TEST_F(ValidateCapability, Default) { } // clang-format off -const vector& AllCapabilities() { - static const auto r = new vector{ +const std::vector& AllCapabilities() { + static const auto r = new std::vector{ "", "Matrix", "Shader", @@ -187,12 +190,31 @@ const vector& AllCapabilities() { "MultiViewport", "SubgroupDispatch", "NamedBarrier", - "PipeStorage"}; + "PipeStorage", + "GroupNonUniform", + "GroupNonUniformVote", + "GroupNonUniformArithmetic", + "GroupNonUniformBallot", + "GroupNonUniformShuffle", + "GroupNonUniformShuffleRelative", + "GroupNonUniformClustered", + "GroupNonUniformQuad", + "DrawParameters", + "StorageBuffer16BitAccess", + "StorageUniformBufferBlock16", + "UniformAndStorageBuffer16BitAccess", + "StorageUniform16", + "StoragePushConstant16", + "StorageInputOutput16", + "DeviceGroup", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers"}; return *r; } -const vector& AllSpirV10Capabilities() { - static const auto r = new vector{ +const std::vector& AllSpirV10Capabilities() { + static const auto r = new std::vector{ "", "Matrix", "Shader", @@ -253,8 +275,8 @@ const vector& AllSpirV10Capabilities() { return *r; } -const vector& AllVulkan10Capabilities() { - static const auto r = new vector{ +const std::vector& AllVulkan10Capabilities() { + static const auto r = new std::vector{ "", "Matrix", "Shader", @@ -294,8 +316,68 @@ const vector& AllVulkan10Capabilities() { return *r; } -const vector& MatrixDependencies() { - static const auto r = new vector{ +const std::vector& AllVulkan11Capabilities() { + static const auto r = new std::vector{ + "", + "Matrix", + "Shader", + "InputAttachment", + "Sampled1D", + "Image1D", + "SampledBuffer", + "ImageBuffer", + "ImageQuery", + "DerivativeControl", + "Geometry", + "Tessellation", + "Float64", + "Int64", + "Int16", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + "StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "SparseResidency", + "MinLod", + "SampledCubeArray", + "ImageMSArray", + "StorageImageExtendedFormats", + "InterpolationFunction", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport", + "GroupNonUniform", + "GroupNonUniformVote", + "GroupNonUniformArithmetic", + "GroupNonUniformBallot", + "GroupNonUniformShuffle", + "GroupNonUniformShuffleRelative", + "GroupNonUniformClustered", + "GroupNonUniformQuad", + "DrawParameters", + "StorageBuffer16BitAccess", + "StorageUniformBufferBlock16", + "UniformAndStorageBuffer16BitAccess", + "StorageUniform16", + "StoragePushConstant16", + "StorageInputOutput16", + "DeviceGroup", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers"}; + return *r; +} + +const std::vector& MatrixDependencies() { + static const auto r = new std::vector{ "Matrix", "Shader", "Geometry", @@ -328,12 +410,16 @@ const vector& MatrixDependencies() { "GeometryStreams", "StorageImageReadWithoutFormat", "StorageImageWriteWithoutFormat", - "MultiViewport"}; + "MultiViewport", + "DrawParameters", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers"}; return *r; } -const vector& ShaderDependencies() { - static const auto r = new vector{ +const std::vector& ShaderDependencies() { + static const auto r = new std::vector{ "Shader", "Geometry", "Tessellation", @@ -365,19 +451,23 @@ const vector& ShaderDependencies() { "GeometryStreams", "StorageImageReadWithoutFormat", "StorageImageWriteWithoutFormat", - "MultiViewport"}; + "MultiViewport", + "DrawParameters", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers"}; return *r; } -const vector& TessellationDependencies() { - static const auto r = new vector{ +const std::vector& TessellationDependencies() { + static const auto r = new std::vector{ "Tessellation", "TessellationPointSize"}; return *r; } -const vector& GeometryDependencies() { - static const auto r = new vector{ +const std::vector& GeometryDependencies() { + static const auto r = new std::vector{ "Geometry", "GeometryPointSize", "GeometryStreams", @@ -385,8 +475,8 @@ const vector& GeometryDependencies() { return *r; } -const vector& GeometryTessellationDependencies() { - static const auto r = new vector{ +const std::vector& GeometryTessellationDependencies() { + static const auto r = new std::vector{ "Tessellation", "TessellationPointSize", "Geometry", @@ -398,8 +488,8 @@ const vector& GeometryTessellationDependencies() { // Returns the names of capabilities that directly depend on Kernel, // plus itself. -const vector& KernelDependencies() { - static const auto r = new vector{ +const std::vector& KernelDependencies() { + static const auto r = new std::vector{ "Kernel", "Vector16", "Float16Buffer", @@ -409,36 +499,60 @@ const vector& KernelDependencies() { "Pipes", "DeviceEnqueue", "LiteralSampler", - "Int8", "SubgroupDispatch", "NamedBarrier", "PipeStorage"}; return *r; } -const vector& AddressesDependencies() { - static const auto r = new vector{ +const std::vector& KernelAndGroupNonUniformDependencies() { + static const auto r = new std::vector{ + "Kernel", + "Vector16", + "Float16Buffer", + "ImageBasic", + "ImageReadWrite", + "ImageMipmap", + "Pipes", + "DeviceEnqueue", + "LiteralSampler", + "SubgroupDispatch", + "NamedBarrier", + "PipeStorage", + "GroupNonUniform", + "GroupNonUniformVote", + "GroupNonUniformArithmetic", + "GroupNonUniformBallot", + "GroupNonUniformShuffle", + "GroupNonUniformShuffleRelative", + "GroupNonUniformClustered", + "GroupNonUniformQuad"}; + return *r; +} + +const std::vector& AddressesDependencies() { + static const auto r = new std::vector{ "Addresses", "GenericPointer"}; return *r; } -const vector& Sampled1DDependencies() { - static const auto r = new vector{ +const std::vector& Sampled1DDependencies() { + static const auto r = new std::vector{ "Sampled1D", "Image1D"}; return *r; } -const vector& SampledRectDependencies() { - static const auto r = new vector{ +const std::vector& SampledRectDependencies() { + static const auto r = new std::vector{ "SampledRect", "ImageRect"}; return *r; } -const vector& SampledBufferDependencies() { - static const auto r = new vector{ +const std::vector& SampledBufferDependencies() { + static const auto r = new std::vector{ "SampledBuffer", "ImageBuffer"}; return *r; @@ -471,289 +585,289 @@ INSTANTIATE_TEST_CASE_P(ExecutionModel, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint TessellationControl %func \"shader\"" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint TessellationEvaluation %func \"shader\"" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Geometry %func \"shader\"" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Fragment %func \"shader\"" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint GLCompute %func \"shader\"" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Kernel %func \"shader\"" + - string(kVoidFVoid), KernelDependencies()) + std::string(kVoidFVoid), KernelDependencies()) )),); INSTANTIATE_TEST_CASE_P(AddressingAndMemoryModel, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(" OpCapability Shader" +std::make_pair(" OpCapability Shader" " OpMemoryModel Logical Simple" " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), AllCapabilities()), -make_pair(" OpCapability Shader" + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(" OpCapability Shader" " OpMemoryModel Logical GLSL450" " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), AllCapabilities()), -make_pair(" OpCapability Kernel" + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(" OpCapability Kernel" " OpMemoryModel Logical OpenCL" " OpEntryPoint Kernel %func \"compute\"" + - string(kVoidFVoid), AllCapabilities()), -make_pair(" OpCapability Shader" + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(" OpCapability Shader" " OpMemoryModel Physical32 Simple" " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), AddressesDependencies()), -make_pair(" OpCapability Shader" + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Shader" " OpMemoryModel Physical32 GLSL450" " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), AddressesDependencies()), -make_pair(" OpCapability Kernel" + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Kernel" " OpMemoryModel Physical32 OpenCL" " OpEntryPoint Kernel %func \"compute\"" + - string(kVoidFVoid), AddressesDependencies()), -make_pair(" OpCapability Shader" + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Shader" " OpMemoryModel Physical64 Simple" " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), AddressesDependencies()), -make_pair(" OpCapability Shader" + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Shader" " OpMemoryModel Physical64 GLSL450" " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid), AddressesDependencies()), -make_pair(" OpCapability Kernel" + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Kernel" " OpMemoryModel Physical64 OpenCL" " OpEntryPoint Kernel %func \"compute\"" + - string(kVoidFVoid), AddressesDependencies()) + std::string(kVoidFVoid), AddressesDependencies()) )),); INSTANTIATE_TEST_CASE_P(ExecutionMode, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func Invocations 42" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func SpacingEqual" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func SpacingFractionalEven" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func SpacingFractionalOdd" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func VertexOrderCw" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func VertexOrderCcw" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func PixelCenterInteger" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func OriginUpperLeft" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func OriginLowerLeft" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func EarlyFragmentTests" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func PointMode" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func Xfb" + - string(kVoidFVoid), vector{"TransformFeedback"}), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), std::vector{"TransformFeedback"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func DepthReplacing" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func DepthGreater" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func DepthLess" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Vertex %func \"shader\" " "OpExecutionMode %func DepthUnchanged" + - string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"shader\" " "OpExecutionMode %func LocalSize 42 42 42" + - string(kVoidFVoid), AllCapabilities()), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Kernel %func \"shader\" " "OpExecutionMode %func LocalSizeHint 42 42 42" + - string(kVoidFVoid), KernelDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), KernelDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func InputPoints" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func InputLines" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func InputLinesAdjacency" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func Triangles" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func Triangles" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func InputTrianglesAdjacency" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func Quads" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func Isolines" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func OutputVertices 42" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint TessellationControl %func \"shader\" " "OpExecutionMode %func OutputVertices 42" + - string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func OutputPoints" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func OutputLineStrip" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Geometry %func \"shader\" " "OpExecutionMode %func OutputTriangleStrip" + - string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Kernel %func \"shader\" " "OpExecutionMode %func VecTypeHint 2" + - string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Kernel %func \"shader\" " "OpExecutionMode %func ContractionOff" + - string(kVoidFVoid), KernelDependencies()))),); + std::string(kVoidFVoid), KernelDependencies()))),); // clang-format on INSTANTIATE_TEST_CASE_P( ExecutionModeV11, ValidateCapabilityV11, Combine(ValuesIn(AllCapabilities()), - Values(make_pair(string(kOpenCLMemoryModel) + - "OpEntryPoint Kernel %func \"shader\" " - "OpExecutionMode %func SubgroupSize 1" + - string(kVoidFVoid), - vector{"SubgroupDispatch"}), - make_pair( - string(kOpenCLMemoryModel) + + Values(std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func SubgroupSize 1" + + std::string(kVoidFVoid), + std::vector{"SubgroupDispatch"}), + std::make_pair( + std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"shader\" " "OpExecutionMode %func SubgroupsPerWorkgroup 65535" + - string(kVoidFVoid), - vector{"SubgroupDispatch"}))), ); + std::string(kVoidFVoid), + std::vector{"SubgroupDispatch"}))), ); // clang-format off INSTANTIATE_TEST_CASE_P(StorageClass, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer UniformConstant %intt\n" - " %var = OpVariable %ptrt UniformConstant\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt UniformConstant\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Kernel %func \"compute\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer Input %intt" - " %var = OpVariable %ptrt Input\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt Input\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer Uniform %intt\n" - " %var = OpVariable %ptrt Uniform\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt Uniform\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer Output %intt\n" - " %var = OpVariable %ptrt Output\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt Output\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer Workgroup %intt\n" - " %var = OpVariable %ptrt Workgroup\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt Workgroup\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer CrossWorkgroup %intt\n" - " %var = OpVariable %ptrt CrossWorkgroup\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt CrossWorkgroup\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Kernel %func \"compute\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer Private %intt\n" - " %var = OpVariable %ptrt Private\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt Private\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + " OpEntryPoint Kernel %func \"compute\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer PushConstant %intt\n" - " %var = OpVariable %ptrt PushConstant\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt PushConstant\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer AtomicCounter %intt\n" - " %var = OpVariable %ptrt AtomicCounter\n" + string(kVoidFVoid), - vector{"AtomicStorage"}), -make_pair(string(kGLSL450MemoryModel) + + " %var = OpVariable %ptrt AtomicCounter\n" + std::string(kVoidFVoid), + std::vector{"AtomicStorage"}), +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" + " %intt = OpTypeInt 32 0\n" " %ptrt = OpTypePointer Image %intt\n" - " %var = OpVariable %ptrt Image\n" + string(kVoidFVoid), + " %var = OpVariable %ptrt Image\n" + std::string(kVoidFVoid), AllCapabilities()) )),); @@ -761,48 +875,48 @@ INSTANTIATE_TEST_CASE_P(Dim, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt 1D 0 0 0 0 Unknown" + string(kVoidFVoid2), + " %imgt = OpTypeImage %voidt 1D 0 0 0 0 Unknown" + std::string(kVoidFVoid2), Sampled1DDependencies()), -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt 2D 0 0 0 0 Unknown" + string(kVoidFVoid2), + " %imgt = OpTypeImage %voidt 2D 0 0 0 0 Unknown" + std::string(kVoidFVoid2), AllCapabilities()), -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt 3D 0 0 0 0 Unknown" + string(kVoidFVoid2), + " %imgt = OpTypeImage %voidt 3D 0 0 0 0 Unknown" + std::string(kVoidFVoid2), AllCapabilities()), -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt Cube 0 0 0 0 Unknown" + string(kVoidFVoid2), + " %imgt = OpTypeImage %voidt Cube 0 0 0 0 Unknown" + std::string(kVoidFVoid2), ShaderDependencies()), -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt Rect 0 0 0 0 Unknown" + string(kVoidFVoid2), + " %imgt = OpTypeImage %voidt Rect 0 0 0 0 Unknown" + std::string(kVoidFVoid2), SampledRectDependencies()), -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt Buffer 0 0 0 0 Unknown" + string(kVoidFVoid2), + " %imgt = OpTypeImage %voidt Buffer 0 0 0 0 Unknown" + std::string(kVoidFVoid2), SampledBufferDependencies()), -make_pair(" OpCapability ImageBasic" + - string(kOpenCLMemoryModel) + - string(" OpEntryPoint Kernel %func \"compute\"") + +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + " %voidt = OpTypeVoid" - " %imgt = OpTypeImage %voidt SubpassData 0 0 0 2 Unknown" + string(kVoidFVoid2), - vector{"InputAttachment"}) + " %imgt = OpTypeImage %voidt SubpassData 0 0 0 2 Unknown" + std::string(kVoidFVoid2), + std::vector{"InputAttachment"}) )),); // NOTE: All Sampler Address Modes require kernel capabilities but the @@ -811,518 +925,521 @@ INSTANTIATE_TEST_CASE_P(SamplerAddressingMode, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" " %samplert = OpTypeSampler" " %sampler = OpConstantSampler %samplert None 1 Nearest" + - string(kVoidFVoid), - vector{"LiteralSampler"}), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" " %samplert = OpTypeSampler" " %sampler = OpConstantSampler %samplert ClampToEdge 1 Nearest" + - string(kVoidFVoid), - vector{"LiteralSampler"}), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" " %samplert = OpTypeSampler" " %sampler = OpConstantSampler %samplert Clamp 1 Nearest" + - string(kVoidFVoid), - vector{"LiteralSampler"}), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" " %samplert = OpTypeSampler" " %sampler = OpConstantSampler %samplert Repeat 1 Nearest" + - string(kVoidFVoid), - vector{"LiteralSampler"}), -make_pair(string(kGLSL450MemoryModel) + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + " OpEntryPoint Vertex %func \"shader\"" " %samplert = OpTypeSampler" " %sampler = OpConstantSampler %samplert RepeatMirrored 1 Nearest" + - string(kVoidFVoid), - vector{"LiteralSampler"}) + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}) )),); -//TODO(umar): Sampler Filter Mode -//TODO(umar): Image Format -//TODO(umar): Image Channel Order -//TODO(umar): Image Channel Data Type -//TODO(umar): Image Operands -//TODO(umar): FP Fast Math Mode -//TODO(umar): FP Rounding Mode -//TODO(umar): Linkage Type -//TODO(umar): Access Qualifier -//TODO(umar): Function Parameter Attribute +// TODO(umar): Sampler Filter Mode +// TODO(umar): Image Format +// TODO(umar): Image Channel Order +// TODO(umar): Image Channel Data Type +// TODO(umar): Image Operands +// TODO(umar): FP Fast Math Mode +// TODO(umar): FP Rounding Mode +// TODO(umar): Linkage Type +// TODO(umar): Access Qualifier +// TODO(umar): Function Parameter Attribute INSTANTIATE_TEST_CASE_P(Decoration, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt RelaxedPrecision\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Block\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt BufferBlock\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt RowMajor\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), MatrixDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt ColMajor\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), MatrixDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt ArrayStride 1\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt MatrixStride 1\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), MatrixDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt GLSLShared\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt GLSLPacked\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpDecorate %intt CPacked\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt NoPerspective\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Flat\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Patch\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Centroid\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Sample\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"SampleRateShading"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"SampleRateShading"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Invariant\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Restrict\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Aliased\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Volatile\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpDecorate %intt Constant\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Coherent\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt NonWritable\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt NonReadable\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Uniform\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpDecorate %intt SaturatedConversion\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Stream 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"GeometryStreams"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"GeometryStreams"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Location 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Component 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Index 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Binding 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt DescriptorSet 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt Offset 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt XfbBuffer 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"TransformFeedback"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"TransformFeedback"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt XfbStride 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"TransformFeedback"}), -make_pair(string(kGLSL450MemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"TransformFeedback"}), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpDecorate %intt FuncParamAttr Zext\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpDecorate %intt FPFastMathMode Fast\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt LinkageAttributes \"other\" Import\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"Linkage"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"Linkage"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt NoContraction\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" "OpDecorate %intt InputAttachmentIndex 0\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"InputAttachment"}), -make_pair(string(kGLSL450MemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"InputAttachment"}), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpDecorate %intt Alignment 4\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()) )),); // clang-format on INSTANTIATE_TEST_CASE_P( DecorationSpecId, ValidateCapability, - Combine(ValuesIn(AllSpirV10Capabilities()), - Values(make_pair(string(kOpenCLMemoryModel) + - "OpEntryPoint Vertex %func \"shader\" \n" + - "OpDecorate %1 SpecId 1\n" - "%intt = OpTypeInt 32 0\n" - "%1 = OpSpecConstant %intt 0\n" + - string(kVoidFVoid), - ShaderDependencies()))), ); + Combine( + ValuesIn(AllSpirV10Capabilities()), + Values(std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %1 SpecId 1\n" + "%intt = OpTypeInt 32 0\n" + "%1 = OpSpecConstant %intt 0\n" + + std::string(kVoidFVoid), + ShaderDependencies()))), ); INSTANTIATE_TEST_CASE_P( DecorationV11, ValidateCapabilityV11, Combine(ValuesIn(AllCapabilities()), - Values(make_pair(string(kOpenCLMemoryModel) + - "OpEntryPoint Kernel %func \"compute\" \n" - "OpDecorate %p MaxByteOffset 0 " - "%i32 = OpTypeInt 32 0 " - "%pi32 = OpTypePointer Workgroup %i32 " - "%p = OpVariable %pi32 Workgroup " + - string(kVoidFVoid), - AddressesDependencies()), + Values(std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %p MaxByteOffset 0 " + "%i32 = OpTypeInt 32 0 " + "%pi32 = OpTypePointer Workgroup %i32 " + "%p = OpVariable %pi32 Workgroup " + + std::string(kVoidFVoid), + AddressesDependencies()), // Trying to test OpDecorate here, but if this fails due to // incorrect OpMemoryModel validation, that must also be // fixed. - make_pair(string("OpMemoryModel Logical OpenCL " - "OpEntryPoint Kernel %func \"compute\" \n" - "OpDecorate %1 SpecId 1 " - "%intt = OpTypeInt 32 0 " - "%1 = OpSpecConstant %intt 0") + - string(kVoidFVoid), - KernelDependencies()), - make_pair(string("OpMemoryModel Logical Simple " - "OpEntryPoint Vertex %func \"shader\" \n" - "OpDecorate %1 SpecId 1 " - "%intt = OpTypeInt 32 0 " - "%1 = OpSpecConstant %intt 0") + - string(kVoidFVoid), - ShaderDependencies()))), ); + std::make_pair( + std::string("OpMemoryModel Logical OpenCL " + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %1 SpecId 1 " + "%intt = OpTypeInt 32 0 " + "%1 = OpSpecConstant %intt 0") + + std::string(kVoidFVoid), + KernelDependencies()), + std::make_pair( + std::string("OpMemoryModel Logical Simple " + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %1 SpecId 1 " + "%intt = OpTypeInt 32 0 " + "%1 = OpSpecConstant %intt 0") + + std::string(kVoidFVoid), + ShaderDependencies()))), ); // clang-format off INSTANTIATE_TEST_CASE_P(BuiltIn, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn Position\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), // Just mentioning PointSize, ClipDistance, or CullDistance as a BuiltIn does // not trigger the requirement for the associated capability. // See https://github.com/KhronosGroup/SPIRV-Tools/issues/365 -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn PointSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn ClipDistance\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn CullDistance\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn VertexId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn InstanceId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn PrimitiveId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), GeometryTessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn InvocationId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), GeometryTessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn Layer\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), GeometryDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn ViewportIndex\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"MultiViewport"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"MultiViewport"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn TessLevelOuter\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn TessLevelInner\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn TessCoord\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn PatchVertices\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), TessellationDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn FragCoord\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn PointCoord\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn FrontFacing\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn SampleId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"SampleRateShading"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"SampleRateShading"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn SamplePosition\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - vector{"SampleRateShading"}), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"SampleRateShading"}), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn SampleMask\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn FragDepth\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn HelperInvocation\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn VertexIndex\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn InstanceIndex\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn NumWorkgroups\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn WorkgroupSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn WorkgroupId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn LocalInvocationId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn GlobalInvocationId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn LocalInvocationIndex\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllCapabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn WorkDim\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn GlobalSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn EnqueuedWorkgroupSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn GlobalOffset\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn GlobalLinearId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn SubgroupSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn SubgroupMaxSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn NumSubgroups\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn NumEnqueuedSubgroups\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn SubgroupId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - KernelDependencies()), -make_pair(string(kGLSL450MemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn SubgroupLocalInvocationId\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), - KernelDependencies()), -make_pair(string(kOpenCLMemoryModel) + + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn VertexIndex\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()), -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BuiltIn InstanceIndex\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), ShaderDependencies()) )),); @@ -1335,31 +1452,31 @@ INSTANTIATE_TEST_CASE_P(BuiltIn, ValidateCapabilityVulkan10, // All capabilities to try. ValuesIn(AllSpirV10Capabilities()), Values( -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpMemberDecorate %block 0 BuiltIn PointSize\n" "%f32 = OpTypeFloat 32\n" "%block = OpTypeStruct %f32\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), // Capabilities which should succeed. AllVulkan10Capabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpMemberDecorate %block 0 BuiltIn ClipDistance\n" "%f32 = OpTypeFloat 32\n" "%intt = OpTypeInt 32 0\n" "%intt_4 = OpConstant %intt 4\n" "%f32arr4 = OpTypeArray %f32 %intt_4\n" - "%block = OpTypeStruct %f32arr4\n" + string(kVoidFVoid), + "%block = OpTypeStruct %f32arr4\n" + std::string(kVoidFVoid), AllVulkan10Capabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" "OpMemberDecorate %block 0 BuiltIn CullDistance\n" "%f32 = OpTypeFloat 32\n" "%intt = OpTypeInt 32 0\n" "%intt_4 = OpConstant %intt 4\n" "%f32arr4 = OpTypeArray %f32 %intt_4\n" - "%block = OpTypeStruct %f32arr4\n" + string(kVoidFVoid), + "%block = OpTypeStruct %f32arr4\n" + std::string(kVoidFVoid), AllVulkan10Capabilities()) )),); @@ -1368,23 +1485,40 @@ INSTANTIATE_TEST_CASE_P(BuiltIn, ValidateCapabilityOpenGL40, // OpenGL 4.0 is based on SPIR-V 1.0 ValuesIn(AllSpirV10Capabilities()), Values( -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn PointSize\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllSpirV10Capabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn ClipDistance\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllSpirV10Capabilities()), -make_pair(string(kGLSL450MemoryModel) + +std::make_pair(std::string(kGLSL450MemoryModel) + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt BuiltIn CullDistance\n" - "%intt = OpTypeInt 32 0\n" + string(kVoidFVoid), + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), AllSpirV10Capabilities()) )),); +INSTANTIATE_TEST_CASE_P(Capabilities, ValidateCapabilityVulkan11, + Combine( + // All capabilities to try. + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn PointSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllVulkan11Capabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn CullDistance\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllVulkan11Capabilities()) +)),); + // TODO(umar): Selection Control // TODO(umar): Loop Control // TODO(umar): Function Control @@ -1399,11 +1533,11 @@ INSTANTIATE_TEST_CASE_P(MatrixOp, ValidateCapability, Combine( ValuesIn(AllCapabilities()), Values( -make_pair(string(kOpenCLMemoryModel) + +std::make_pair(std::string(kOpenCLMemoryModel) + "OpEntryPoint Kernel %func \"compute\" \n" + "%f32 = OpTypeFloat 32\n" "%vec3 = OpTypeVector %f32 3\n" - "%mat33 = OpTypeMatrix %vec3 3\n" + string(kVoidFVoid), + "%mat33 = OpTypeMatrix %vec3 3\n" + std::string(kVoidFVoid), MatrixDependencies()))),); // clang-format on @@ -1415,7 +1549,7 @@ make_pair(string(kOpenCLMemoryModel) + // the image-operands part. The assembly defines constants %fzero and %izero // that can be used for operands where IDs are required. The assembly is valid, // apart from not declaring any capabilities required by the operands. -string ImageOperandsTemplate(const string& operands) { +string ImageOperandsTemplate(const std::string& operands) { ostringstream ss; // clang-format off ss << R"( @@ -1449,34 +1583,66 @@ INSTANTIATE_TEST_CASE_P( TwoImageOperandsMask, ValidateCapability, Combine( ValuesIn(AllCapabilities()), - Values(make_pair(ImageOperandsTemplate("Bias|Lod %fzero %fzero"), + Values(std::make_pair(ImageOperandsTemplate("Bias|Lod %fzero %fzero"), ShaderDependencies()), - make_pair(ImageOperandsTemplate("Lod|Offset %fzero %izero"), - vector{"ImageGatherExtended"}), - make_pair(ImageOperandsTemplate("Sample|MinLod %izero %fzero"), - vector{"MinLod"}), - make_pair(ImageOperandsTemplate("Lod|Sample %fzero %izero"), + std::make_pair(ImageOperandsTemplate("Lod|Offset %fzero %izero"), + std::vector{"ImageGatherExtended"}), + std::make_pair(ImageOperandsTemplate("Sample|MinLod %izero %fzero"), + std::vector{"MinLod"}), + std::make_pair(ImageOperandsTemplate("Lod|Sample %fzero %izero"), AllCapabilities()))), ); #endif // TODO(umar): Instruction capability checks -// True if capability exists in env. +spv_result_t spvCoreOperandTableNameLookup(spv_target_env env, + const spv_operand_table table, + const spv_operand_type_t type, + const char* name, + const size_t nameLength) { + if (!table) return SPV_ERROR_INVALID_TABLE; + if (!name) return SPV_ERROR_INVALID_POINTER; + + for (uint64_t typeIndex = 0; typeIndex < table->count; ++typeIndex) { + const auto& group = table->types[typeIndex]; + if (type != group.type) continue; + for (uint64_t index = 0; index < group.count; ++index) { + const auto& entry = group.entries[index]; + // Check for min version only. + if (spvVersionForTargetEnv(env) >= entry.minVersion && + nameLength == strlen(entry.name) && + !strncmp(entry.name, name, nameLength)) { + return SPV_SUCCESS; + } + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +// True if capability exists in core spec of env. bool Exists(const std::string& capability, spv_target_env env) { - spv_operand_desc dummy; - return SPV_SUCCESS == libspirv::AssemblyGrammar(ScopedContext(env).context) - .lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, - capability.c_str(), - capability.size(), &dummy); + ScopedContext sc(env); + return SPV_SUCCESS == + spvCoreOperandTableNameLookup(env, sc.context->operand_table, + SPV_OPERAND_TYPE_CAPABILITY, + capability.c_str(), capability.size()); } TEST_P(ValidateCapability, Capability) { - const string capability = Capability(GetParam()); - spv_target_env env = - (capability.empty() || Exists(capability, SPV_ENV_UNIVERSAL_1_0)) - ? SPV_ENV_UNIVERSAL_1_0 - : SPV_ENV_UNIVERSAL_1_1; - const string test_code = MakeAssembly(GetParam()); + const std::string capability = Capability(GetParam()); + spv_target_env env = SPV_ENV_UNIVERSAL_1_0; + if (!capability.empty()) { + if (Exists(capability, SPV_ENV_UNIVERSAL_1_0)) + env = SPV_ENV_UNIVERSAL_1_0; + else if (Exists(capability, SPV_ENV_UNIVERSAL_1_1)) + env = SPV_ENV_UNIVERSAL_1_1; + else if (Exists(capability, SPV_ENV_UNIVERSAL_1_2)) + env = SPV_ENV_UNIVERSAL_1_2; + else + env = SPV_ENV_UNIVERSAL_1_3; + } + const std::string test_code = MakeAssembly(GetParam()); CompileSuccessfully(test_code, env); ASSERT_EQ(ExpectedResult(GetParam()), ValidateInstructions(env)) << "target env: " << spvTargetEnvDescription(env) << "\ntest code:\n" @@ -1484,27 +1650,47 @@ TEST_P(ValidateCapability, Capability) { } TEST_P(ValidateCapabilityV11, Capability) { - const string test_code = MakeAssembly(GetParam()); - CompileSuccessfully(test_code, SPV_ENV_UNIVERSAL_1_1); - ASSERT_EQ(ExpectedResult(GetParam()), - ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)) - << test_code; + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_UNIVERSAL_1_1)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)) + << test_code; + } } TEST_P(ValidateCapabilityVulkan10, Capability) { - const string test_code = MakeAssembly(GetParam()); - CompileSuccessfully(test_code, SPV_ENV_VULKAN_1_0); - ASSERT_EQ(ExpectedResult(GetParam()), - ValidateInstructions(SPV_ENV_VULKAN_1_0)) - << test_code; + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_VULKAN_1_0)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_VULKAN_1_0); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_VULKAN_1_0)) + << test_code; + } +} + +TEST_P(ValidateCapabilityVulkan11, Capability) { + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_VULKAN_1_1)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_VULKAN_1_1); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_VULKAN_1_1)) + << test_code; + } } TEST_P(ValidateCapabilityOpenGL40, Capability) { - const string test_code = MakeAssembly(GetParam()); - CompileSuccessfully(test_code, SPV_ENV_OPENGL_4_0); - ASSERT_EQ(ExpectedResult(GetParam()), - ValidateInstructions(SPV_ENV_OPENGL_4_0)) - << test_code; + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_OPENGL_4_0)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_OPENGL_4_0); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_OPENGL_4_0)) + << test_code; + } } TEST_F(ValidateCapability, SemanticsIdIsAnIdNotALiteral) { @@ -1611,7 +1797,7 @@ OpEntryPoint Vertex %func "shader" OpMemberDecorate %block 0 BuiltIn PointSize %f32 = OpTypeFloat 32 %block = OpTypeStruct %f32 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); @@ -1625,7 +1811,7 @@ OpMemoryModel Logical GLSL450 OpEntryPoint Vertex %func "shader" OpDecorate %intt BuiltIn PointSize %intt = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1661,7 +1847,7 @@ OpCapability ImageBasic OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_1_2); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_1_2)); @@ -1675,7 +1861,7 @@ OpCapability Linkage OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_1_2); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1713,7 +1899,7 @@ OpCapability ImageBasic OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_1_2); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_1_2)); @@ -1727,7 +1913,7 @@ OpCapability Linkage OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_1_2); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1777,7 +1963,7 @@ OpCapability ImageBasic OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_0); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_2_0)); @@ -1791,7 +1977,7 @@ OpCapability Linkage OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_0); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1827,7 +2013,7 @@ OpCapability ImageBasic OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_0); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_0)); @@ -1841,7 +2027,7 @@ OpCapability Linkage OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_0); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1890,7 +2076,7 @@ OpCapability ImageBasic OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_2); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_2_2)); @@ -1904,7 +2090,7 @@ OpCapability Linkage OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_2); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1942,7 +2128,7 @@ OpCapability ImageBasic OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_2); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_2)); @@ -1956,7 +2142,7 @@ OpCapability Linkage OpCapability Sampled1D OpMemoryModel Physical64 OpenCL %u32 = OpTypeInt 32 0 -)" + string(kVoidFVoid); +)" + std::string(kVoidFVoid); CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_2); EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, @@ -1966,4 +2152,150 @@ OpMemoryModel Physical64 OpenCL "Embedded Profile")); } +// Three tests to check enablement of an enum (a decoration) which is not +// in core, and is directly enabled by a capability, but not directly enabled +// by an extension. See https://github.com/KhronosGroup/SPIRV-Tools/issues/1596 + +TEST_F(ValidateCapability, DecorationFromExtensionMissingEnabledByCapability) { + // Decoration ViewportRelativeNV is enabled by ShaderViewportMaskNV, which in + // turn is enabled by SPV_NV_viewport_array2. + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical Simple +OpDecorate %void ViewportRelativeNV +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 2 of Decorate requires one of these " + "capabilities: ShaderViewportMaskNV")); +} + +TEST_F(ValidateCapability, CapabilityEnabledByMissingExtension) { + // Capability ShaderViewportMaskNV is enabled by SPV_NV_viewport_array2. + const std::string spirv = R"( +OpCapability Shader +OpCapability ShaderViewportMaskNV +OpMemoryModel Logical Simple +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_MISSING_EXTENSION, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("operand 5255 requires one of these extensions: " + "SPV_NV_viewport_array2")); +} + +TEST_F(ValidateCapability, + DecorationEnabledByCapabilityEnabledByPresentExtension) { + // Decoration ViewportRelativeNV is enabled by ShaderViewportMaskNV, which in + // turn is enabled by SPV_NV_viewport_array2. + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability ShaderViewportMaskNV +OpExtension "SPV_NV_viewport_array2" +OpMemoryModel Logical Simple +OpDecorate %void ViewportRelativeNV +%void = OpTypeVoid +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)) + << getDiagnosticString(); +} + +// Three tests to check enablement of an instruction which is not in core, and +// is directly enabled by a capability, but not directly enabled by an +// extension. See https://github.com/KhronosGroup/SPIRV-Tools/issues/1624 +// Instruction OpSubgroupShuffleINTEL is enabled by SubgroupShuffleINTEL, which +// in turn is enabled by SPV_INTEL_subgroups. + +TEST_F(ValidateCapability, InstructionFromExtensionMissingEnabledByCapability) { + // Decoration ViewportRelativeNV is enabled by ShaderViewportMaskNV, which in + // turn is enabled by SPV_NV_viewport_array2. + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +; OpCapability SubgroupShuffleINTEL +OpExtension "SPV_INTEL_subgroups" +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%voidfn = OpTypeFunction %void +%zero = OpConstant %uint 0 +%main = OpFunction %void None %voidfn +%entry = OpLabel +%foo = OpSubgroupShuffleINTEL %uint %zero %zero +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Opcode SubgroupShuffleINTEL requires one of these " + "capabilities: SubgroupShuffleINTEL")); +} + +TEST_F(ValidateCapability, + InstructionEnablingCapabilityEnabledByMissingExtension) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability SubgroupShuffleINTEL +; OpExtension "SPV_INTEL_subgroups" +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%voidfn = OpTypeFunction %void +%zero = OpConstant %uint 0 +%main = OpFunction %void None %voidfn +%entry = OpLabel +%foo = OpSubgroupShuffleINTEL %uint %zero %zero +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_MISSING_EXTENSION, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("operand 5568 requires one of these extensions: " + "SPV_INTEL_subgroups")); +} + +TEST_F(ValidateCapability, + InstructionEnabledByCapabilityEnabledByPresentExtension) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability SubgroupShuffleINTEL +OpExtension "SPV_INTEL_subgroups" +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%voidfn = OpTypeFunction %void +%zero = OpConstant %uint 0 +%main = OpFunction %void None %voidfn +%entry = OpLabel +%foo = OpSubgroupShuffleINTEL %uint %zero %zero +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)) + << getDiagnosticString(); +} + } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_cfg_test.cpp b/3rdparty/spirv-tools/test/val/val_cfg_test.cpp index 78cbe093d..045166925 100644 --- a/3rdparty/spirv-tools/test/val/val_cfg_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_cfg_test.cpp @@ -25,56 +25,48 @@ #include "gmock/gmock.h" #include "source/diagnostic.h" -#include "source/validate.h" -#include "test_fixture.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "source/val/validate.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" -using std::array; -using std::make_pair; -using std::pair; -using std::string; -using std::stringstream; -using std::vector; +namespace spvtools { +namespace val { +namespace { using ::testing::HasSubstr; using ::testing::MatchesRegex; -using libspirv::BasicBlock; -using libspirv::ValidationState_t; - using ValidateCFG = spvtest::ValidateBase; using spvtest::ScopedContext; -namespace { - -string nameOps() { return ""; } +std::string nameOps() { return ""; } template -string nameOps(pair head, Args... names) { +std::string nameOps(std::pair head, Args... names) { return "OpName %" + head.first + " \"" + head.second + "\"\n" + nameOps(names...); } template -string nameOps(string head, Args... names) { +std::string nameOps(std::string head, Args... names) { return "OpName %" + head + " \"" + head + "\"\n" + nameOps(names...); } /// This class allows the easy creation of complex control flow without writing /// SPIR-V. This class is used in the test cases below. class Block { - string label_; - string body_; + std::string label_; + std::string body_; SpvOp type_; - vector successors_; + std::vector successors_; public: /// Creates a Block with a given label /// /// @param[in]: label the label id of the block /// @param[in]: type the branch instruciton that ends the block - explicit Block(string label, SpvOp type = SpvOpBranch) + explicit Block(std::string label, SpvOp type = SpvOpBranch) : label_(label), body_(), type_(type), successors_() {} /// Sets the instructions which will appear in the body of the block @@ -89,8 +81,8 @@ class Block { } /// Converts the block into a SPIR-V string - operator string() { - stringstream out; + operator std::string() { + std::stringstream out; out << std::setw(8) << "%" + label_ + " = OpLabel \n"; if (!body_.empty()) { out << body_; @@ -105,7 +97,7 @@ class Block { break; case SpvOpSwitch: { out << "OpSwitch %one %" + successors_.front().label_; - stringstream ss; + std::stringstream ss; for (size_t i = 1; i < successors_.size(); i++) { ss << " " << i << " %" << successors_[i].label_; } @@ -130,12 +122,12 @@ class Block { return out.str(); } - friend Block& operator>>(Block& curr, vector successors); + friend Block& operator>>(Block& curr, std::vector successors); friend Block& operator>>(Block& lhs, Block& successor); }; /// Assigns the successors for the Block on the lhs -Block& operator>>(Block& lhs, vector successors) { +Block& operator>>(Block& lhs, std::vector successors) { if (lhs.type_ == SpvOpBranchConditional) { assert(successors.size() == 2); } else if (lhs.type_ == SpvOpSwitch) { @@ -193,7 +185,7 @@ TEST_P(ValidateCFG, LoopReachableFromEntryButNeverLeadingToReturn) { // // For more motivation, see // https://github.com/KhronosGroup/SPIRV-Tools/issues/279 - string str = R"( + std::string str = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -231,7 +223,7 @@ TEST_P(ValidateCFG, LoopUnreachableFromEntryButLeadingToReturn) { // https://github.com/KhronosGroup/SPIRV-Tools/issues/279 // Before that fix, we'd have an infinite loop when calculating // post-dominators. - string str = R"( + std::string str = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -278,13 +270,14 @@ TEST_P(ValidateCFG, Simple) { loop.SetBody("OpLoopMerge %merge %cont None\n"); } - string str = - header(GetParam()) + - nameOps("loop", "entry", "cont", "merge", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("loop", "entry", "cont", "merge", + std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({cont, merge}); + str += loop >> std::vector({cont, merge}); str += cont >> loop; str += merge; str += "OpFunctionEnd\n"; @@ -300,8 +293,9 @@ TEST_P(ValidateCFG, Variable) { entry.SetBody("%var = OpVariable %ptrt Function\n"); - string str = header(GetParam()) + nameOps(make_pair("func", "Main")) + - types_consts() + " %func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps(std::make_pair("func", "Main")) + types_consts() + + " %func = OpFunction %voidt None %funct\n"; str += entry >> cont; str += cont >> exit; str += exit; @@ -319,8 +313,9 @@ TEST_P(ValidateCFG, VariableNotInFirstBlockBad) { // This operation should only be performed in the entry block cont.SetBody("%var = OpVariable %ptrt Function\n"); - string str = header(GetParam()) + nameOps(make_pair("func", "Main")) + - types_consts() + " %func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps(std::make_pair("func", "Main")) + types_consts() + + " %func = OpFunction %voidt None %funct\n"; str += entry >> cont; str += cont >> exit; @@ -344,13 +339,14 @@ TEST_P(ValidateCFG, BlockSelfLoopIsOk) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n"); - string str = header(GetParam()) + - nameOps("loop", "merge", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("loop", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; // loop branches to itself, but does not trigger an error. - str += loop >> vector({merge, loop}); + str += loop >> std::vector({merge, loop}); str += merge; str += "OpFunctionEnd\n"; @@ -368,13 +364,14 @@ TEST_P(ValidateCFG, BlockAppearsBeforeDominatorBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) branch.SetBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("cont", "branch", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("cont", "branch", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> branch; str += cont >> merge; // cont appears before its dominator - str += branch >> vector({cont, merge}); + str += branch >> std::vector({cont, merge}); str += merge; str += "OpFunctionEnd\n"; @@ -382,7 +379,8 @@ TEST_P(ValidateCFG, BlockAppearsBeforeDominatorBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("Block .\\[cont\\] appears in the binary " - "before its dominator .\\[branch\\]")); + "before its dominator .\\[branch\\]\n" + " %branch = OpLabel\n")); } TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { @@ -398,13 +396,13 @@ TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { // cannot share the same merge if (is_shader) selection.SetBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("merge", make_pair("func", "Main")) + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; str += loop >> selection; - str += selection >> vector({loop, merge}); + str += selection >> std::vector({loop, merge}); str += merge; str += "OpFunctionEnd\n"; @@ -413,7 +411,8 @@ TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("Block .\\[merge\\] is already a merge block " - "for another header")); + "for another header\n" + " %Main = OpFunction %void None %9\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -432,13 +431,13 @@ TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { // cannot share the same merge if (is_shader) loop.SetBody(" OpLoopMerge %merge %loop None\n"); - string str = header(GetParam()) + - nameOps("merge", make_pair("func", "Main")) + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> selection; - str += selection >> vector({merge, loop}); - str += loop >> vector({loop, merge}); + str += selection >> std::vector({merge, loop}); + str += loop >> std::vector({loop, merge}); str += merge; str += "OpFunctionEnd\n"; @@ -447,19 +446,21 @@ TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("Block .\\[merge\\] is already a merge block " - "for another header")); + "for another header\n" + " %Main = OpFunction %void None %9\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } } -TEST_P(ValidateCFG, BranchTargetFirstBlockBad) { +TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceEntryBlock) { Block entry("entry"); Block bad("bad"); Block end("end", SpvOpReturn); - string str = header(GetParam()) + - nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; str += bad >> entry; // Cannot target entry block @@ -470,7 +471,34 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] " - "is targeted by block .\\[bad\\]")); + "is targeted by block .\\[bad\\]\n" + " %Main = OpFunction %void None %10\n")); +} + +TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceValue) { + Block entry("entry"); + entry.SetBody("%undef = OpUndef %voidt\n"); + Block bad("bad"); + Block end("end", SpvOpReturn); + Block badvalue("undef"); // This referenes the OpUndef. + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> bad; + str += + bad >> badvalue; // Check branch to a function value (it's not a block!) + str += end; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block\\(s\\) \\{..\\} are referenced but not " + "defined in function .\\[Main\\]\n" + " %Main = OpFunction %void None %10\n")) + << str; } TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { @@ -481,12 +509,13 @@ TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); bad.SetBody(" OpLoopMerge %entry %exit None\n"); - string str = header(GetParam()) + - nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; - str += bad >> vector({entry, exit}); // cannot target entry block + str += bad >> std::vector({entry, exit}); // cannot target entry block str += exit; str += "OpFunctionEnd\n"; @@ -494,7 +523,8 @@ TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] " - "is targeted by block .\\[bad\\]")); + "is targeted by block .\\[bad\\]\n" + " %Main = OpFunction %void None %10\n")); } TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { @@ -507,12 +537,13 @@ TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); bad.SetBody("OpLoopMerge %merge %cont None\n"); - string str = header(GetParam()) + - nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; - str += bad >> vector({t, entry}); + str += bad >> std::vector({t, entry}); str += merge >> end; str += end; str += "OpFunctionEnd\n"; @@ -521,7 +552,8 @@ TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] " - "is targeted by block .\\[bad\\]")); + "is targeted by block .\\[bad\\]\n" + " %Main = OpFunction %void None %10\n")); } TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) { @@ -537,12 +569,13 @@ TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); bad.SetBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; - str += bad >> vector({def, block1, block2, block3, entry}); + str += bad >> std::vector({def, block1, block2, block3, entry}); str += def >> merge; str += block1 >> merge; str += block2 >> merge; @@ -555,7 +588,8 @@ TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] " - "is targeted by block .\\[bad\\]")); + "is targeted by block .\\[bad\\]\n" + " %Main = OpFunction %void None %10\n")); } TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) { @@ -570,12 +604,12 @@ TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) { Block middle2("middle2"); Block end2("end2", SpvOpReturn); - string str = header(GetParam()) + - nameOps("middle2", make_pair("func", "Main")) + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + nameOps("middle2", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> middle; - str += middle >> vector({end, middle2}); + str += middle >> std::vector({end, middle2}); str += end; str += "OpFunctionEnd\n"; @@ -589,8 +623,9 @@ TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) { ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - MatchesRegex("Block\\(s\\) \\{.\\[middle2\\] .\\} are referenced but not " - "defined in function .\\[Main\\]")); + MatchesRegex("Block\\(s\\) \\{.\\[middle2\\]\\} are referenced but not " + "defined in function .\\[Main\\]\n" + " %Main = OpFunction %void None %9\n")); } TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) { @@ -604,12 +639,13 @@ TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) { if (is_shader) head.AppendBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("head", "merge", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("head", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> merge; - str += head >> vector({merge, f}); + str += head >> std::vector({merge, f}); str += f >> merge; str += merge; str += "OpFunctionEnd\n"; @@ -621,7 +657,7 @@ TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) { getDiagnosticString(), MatchesRegex("The selection construct with the selection header " ".\\[head\\] does not dominate the merge block " - ".\\[merge\\]")); + ".\\[merge\\]\n %merge = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -638,11 +674,12 @@ TEST_P(ValidateCFG, HeaderDoesntStrictlyDominateMergeBad) { if (is_shader) head.AppendBody("OpSelectionMerge %head None\n"); - string str = header(GetParam()) + - nameOps("head", "exit", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("head", "exit", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; - str += head >> vector({exit, exit}); + str += head >> std::vector({exit, exit}); str += exit; str += "OpFunctionEnd\n"; @@ -653,7 +690,7 @@ TEST_P(ValidateCFG, HeaderDoesntStrictlyDominateMergeBad) { getDiagnosticString(), MatchesRegex("The selection construct with the selection header " ".\\[head\\] does not strictly dominate the merge block " - ".\\[head\\]")); + ".\\[head\\]\n %head = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << str; } @@ -670,12 +707,13 @@ TEST_P(ValidateCFG, UnreachableMerge) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("branch", "merge", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("branch", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> branch; - str += branch >> vector({t, f}); + str += branch >> std::vector({t, f}); str += t; str += f; str += merge; @@ -696,12 +734,13 @@ TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("branch", "merge", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("branch", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> branch; - str += branch >> vector({t, f}); + str += branch >> std::vector({t, f}); str += t; str += f; str += merge; @@ -716,9 +755,10 @@ TEST_P(ValidateCFG, UnreachableBlock) { Block unreachable("unreachable"); Block exit("exit", SpvOpReturn); - string str = header(GetParam()) + - nameOps("unreachable", "exit", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + + nameOps("unreachable", "exit", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> exit; str += unreachable >> exit; @@ -740,12 +780,14 @@ TEST_P(ValidateCFG, UnreachableBranch) { unreachable.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) unreachable.AppendBody("OpSelectionMerge %merge None\n"); - string str = header(GetParam()) + - nameOps("unreachable", "exit", make_pair("func", "Main")) + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + + nameOps("unreachable", "exit", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> exit; - str += unreachable >> vector({unreachablechildt, unreachablechildf}); + str += + unreachable >> std::vector({unreachablechildt, unreachablechildf}); str += unreachablechildt >> merge; str += unreachablechildf >> merge; str += merge >> exit; @@ -757,8 +799,8 @@ TEST_P(ValidateCFG, UnreachableBranch) { } TEST_P(ValidateCFG, EmptyFunction) { - string str = header(GetParam()) + string(types_consts()) + - R"(%func = OpFunction %voidt None %funct + std::string str = header(GetParam()) + std::string(types_consts()) + + R"(%func = OpFunction %voidt None %funct %l = OpLabel OpReturn OpFunctionEnd)"; @@ -776,11 +818,11 @@ TEST_P(ValidateCFG, SingleBlockLoop) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) loop.AppendBody("OpLoopMerge %exit %loop None\n"); - string str = header(GetParam()) + string(types_consts()) + - "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + std::string(types_consts()) + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({loop, exit}); + str += loop >> std::vector({loop, exit}); str += exit; str += "OpFunctionEnd"; @@ -805,13 +847,14 @@ TEST_P(ValidateCFG, NestedLoops) { loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n"); } - string str = header(GetParam()) + nameOps("loop2", "loop2_merge") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("loop2", "loop2_merge") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop1; str += loop1 >> loop1_cont_break_block; - str += loop1_cont_break_block >> vector({loop1_merge, loop2}); - str += loop2 >> vector({loop2, loop2_merge}); + str += loop1_cont_break_block >> std::vector({loop1_merge, loop2}); + str += loop2 >> std::vector({loop2, loop2_merge}); str += loop2_merge >> loop1; str += loop1_merge >> exit; str += exit; @@ -825,8 +868,8 @@ TEST_P(ValidateCFG, NestedSelection) { bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); const int N = 256; - vector if_blocks; - vector merge_blocks; + std::vector if_blocks; + std::vector merge_blocks; Block inner("inner"); entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); @@ -837,21 +880,22 @@ TEST_P(ValidateCFG, NestedSelection) { merge_blocks.emplace_back("if_merge0", SpvOpReturn); for (int i = 1; i < N; i++) { - stringstream ss; + std::stringstream ss; ss << i; if_blocks.emplace_back("if" + ss.str(), SpvOpBranchConditional); if (is_shader) if_blocks[i].SetBody("OpSelectionMerge %if_merge" + ss.str() + " None\n"); merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch); } - string str = header(GetParam()) + string(types_consts()) + - "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + std::string(types_consts()) + + "%func = OpFunction %voidt None %funct\n"; str += entry >> if_blocks[0]; for (int i = 0; i < N - 1; i++) { - str += if_blocks[i] >> vector({if_blocks[i + 1], merge_blocks[i]}); + str += + if_blocks[i] >> std::vector({if_blocks[i + 1], merge_blocks[i]}); } - str += if_blocks.back() >> vector({inner, merge_blocks.back()}); + str += if_blocks.back() >> std::vector({inner, merge_blocks.back()}); str += inner >> merge_blocks.back(); for (int i = N - 1; i > 0; i--) { str += merge_blocks[i] >> merge_blocks[i - 1]; @@ -878,14 +922,15 @@ TEST_P(ValidateCFG, BackEdgeBlockDoesntPostDominateContinueTargetBad) { loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n"); } - string str = header(GetParam()) + - nameOps("loop1", "loop2", "be_block", "loop2_merge") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("loop1", "loop2", "be_block", "loop2_merge") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop1; - str += loop1 >> vector({loop2, exit}); - str += loop2 >> vector({loop2, loop2_merge}); - str += loop2_merge >> vector({be_block, exit}); + str += loop1 >> std::vector({loop2, exit}); + str += loop2 >> std::vector({loop2, loop2_merge}); + str += loop2_merge >> std::vector({be_block, exit}); str += be_block >> loop1; str += exit; str += "OpFunctionEnd"; @@ -896,7 +941,8 @@ TEST_P(ValidateCFG, BackEdgeBlockDoesntPostDominateContinueTargetBad) { EXPECT_THAT(getDiagnosticString(), MatchesRegex("The continue construct with the continue target " ".\\[loop2_merge\\] is not post dominated by the " - "back-edge block .\\[be_block\\]")); + "back-edge block .\\[be_block\\]\n" + " %be_block = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -913,11 +959,12 @@ TEST_P(ValidateCFG, BranchingToNonLoopHeaderBlockBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) split.SetBody("OpSelectionMerge %exit None\n"); - string str = header(GetParam()) + nameOps("split", "f") + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("split", "f") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> split; - str += split >> vector({t, f}); + str += split >> std::vector({t, f}); str += t >> exit; str += f >> split; str += exit; @@ -929,7 +976,8 @@ TEST_P(ValidateCFG, BranchingToNonLoopHeaderBlockBad) { EXPECT_THAT( getDiagnosticString(), MatchesRegex("Back-edges \\(.\\[f\\] -> .\\[split\\]\\) can only " - "be formed between a block and a loop header.")); + "be formed between a block and a loop header.\n" + " %f = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -944,11 +992,11 @@ TEST_P(ValidateCFG, BranchingToSameNonLoopHeaderBlockBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) split.SetBody("OpSelectionMerge %exit None\n"); - string str = header(GetParam()) + nameOps("split") + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("split") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> split; - str += split >> vector({split, exit}); + str += split >> std::vector({split, exit}); str += exit; str += "OpFunctionEnd"; @@ -958,7 +1006,8 @@ TEST_P(ValidateCFG, BranchingToSameNonLoopHeaderBlockBad) { EXPECT_THAT(getDiagnosticString(), MatchesRegex( "Back-edges \\(.\\[split\\] -> .\\[split\\]\\) can only be " - "formed between a block and a loop header.")); + "formed between a block and a loop header.\n" + " %split = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -975,11 +1024,12 @@ TEST_P(ValidateCFG, MultipleBackEdgeBlocksToLoopHeaderBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) loop.SetBody("OpLoopMerge %merge %back0 None\n"); - string str = header(GetParam()) + nameOps("loop", "back0", "back1") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("loop", "back0", "back1") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({back0, back1}); + str += loop >> std::vector({back0, back1}); str += back0 >> loop; str += back1 >> loop; str += merge; @@ -991,7 +1041,8 @@ TEST_P(ValidateCFG, MultipleBackEdgeBlocksToLoopHeaderBad) { EXPECT_THAT(getDiagnosticString(), MatchesRegex( "Loop header .\\[loop\\] is targeted by 2 back-edge blocks " - "but the standard requires exactly one")) + "but the standard requires exactly one\n" + " %loop = OpLabel\n")) << str; } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); @@ -1010,12 +1061,13 @@ TEST_P(ValidateCFG, ContinueTargetMustBePostDominatedByBackEdge) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) loop.SetBody("OpLoopMerge %merge %cheader None\n"); - string str = header(GetParam()) + nameOps("cheader", "be_block") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("cheader", "be_block") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({cheader, merge}); - str += cheader >> vector({exit, be_block}); + str += loop >> std::vector({cheader, merge}); + str += cheader >> std::vector({exit, be_block}); str += exit; // Branches out of a continue construct str += be_block >> loop; str += merge; @@ -1027,7 +1079,8 @@ TEST_P(ValidateCFG, ContinueTargetMustBePostDominatedByBackEdge) { EXPECT_THAT(getDiagnosticString(), MatchesRegex("The continue construct with the continue target " ".\\[cheader\\] is not post dominated by the " - "back-edge block .\\[be_block\\]")); + "back-edge block .\\[be_block\\]\n" + " %be_block = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -1043,12 +1096,13 @@ TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n"); - string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("cont", "loop") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({cont, merge}); - str += cont >> vector({loop, merge}); + str += loop >> std::vector({cont, merge}); + str += cont >> std::vector({loop, merge}); str += merge; str += "OpFunctionEnd"; @@ -1058,7 +1112,8 @@ TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) { EXPECT_THAT(getDiagnosticString(), MatchesRegex("The continue construct with the continue target " ".\\[loop\\] is not post dominated by the " - "back-edge block .\\[cont\\]")) + "back-edge block .\\[cont\\]\n" + " %cont = OpLabel\n")) << str; } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); @@ -1076,12 +1131,13 @@ TEST_P(ValidateCFG, BranchOutOfConstructBad) { entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n"); - string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() + - "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + nameOps("cont", "loop") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({cont, merge}); - str += cont >> vector({loop, exit}); + str += loop >> std::vector({cont, merge}); + str += cont >> std::vector({loop, exit}); str += merge >> exit; str += exit; str += "OpFunctionEnd"; @@ -1092,7 +1148,8 @@ TEST_P(ValidateCFG, BranchOutOfConstructBad) { EXPECT_THAT(getDiagnosticString(), MatchesRegex("The continue construct with the continue target " ".\\[loop\\] is not post dominated by the " - "back-edge block .\\[cont\\]")); + "back-edge block .\\[cont\\]\n" + " %cont = OpLabel\n")); } else { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } @@ -1106,7 +1163,7 @@ TEST_F(ValidateCFG, OpSwitchToUnreachableBlock) { Block def("default", SpvOpUnreachable); Block phi("phi", SpvOpReturn); - string str = R"( + std::string str = R"( OpCapability Shader OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %main "main" %id @@ -1132,7 +1189,7 @@ OpDecorate %id BuiltIn GlobalInvocationId "%x = OpCompositeExtract %u32 %idval 0\n" "%selector = OpUMod %u32 %x %three\n" "OpSelectionMerge %phi None\n"); - str += entry >> vector({def, case0, case1, case2}); + str += entry >> std::vector({def, case0, case1, case2}); str += case1 >> phi; str += def; str += phi; @@ -1145,7 +1202,7 @@ OpDecorate %id BuiltIn GlobalInvocationId } TEST_F(ValidateCFG, LoopWithZeroBackEdgesBad) { - string str = R"( + std::string str = R"( OpCapability Shader OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" @@ -1166,11 +1223,11 @@ TEST_F(ValidateCFG, LoopWithZeroBackEdgesBad) { getDiagnosticString(), MatchesRegex("Loop header .\\[loop\\] is targeted by " "0 back-edge blocks but the standard requires exactly " - "one")); + "one\n %loop = OpLabel\n")); } TEST_F(ValidateCFG, LoopWithBackEdgeFromUnreachableContinueConstructGood) { - string str = R"( + std::string str = R"( OpCapability Shader OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" @@ -1223,11 +1280,12 @@ TEST_P(ValidateCFG, inner_head.SetBody("OpSelectionMerge %inner_merge None\n"); } - string str = header(GetParam()) + nameOps("entry", "inner_merge", "exit") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = header(GetParam()) + + nameOps("entry", "inner_merge", "exit") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; - str += entry >> vector({inner_head, exit}); - str += inner_head >> vector({inner_true, inner_false}); + str += entry >> std::vector({inner_head, exit}); + str += inner_head >> std::vector({inner_true, inner_false}); str += inner_true; str += inner_false; str += inner_merge >> exit; @@ -1259,16 +1317,16 @@ TEST_P(ValidateCFG, ContinueTargetCanBeMergeBlockForNestedStructureGood) { if_head.SetBody("OpSelectionMerge %if_merge None\n"); } - string str = + std::string str = header(GetParam()) + nameOps("entry", "loop", "if_head", "if_true", "if_merge", "merge") + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; str += loop >> if_head; - str += if_head >> vector({if_true, if_merge}); + str += if_head >> std::vector({if_true, if_merge}); str += if_true >> if_merge; - str += if_merge >> vector({loop, merge}); + str += if_merge >> std::vector({loop, merge}); str += merge; str += "OpFunctionEnd"; @@ -1290,12 +1348,13 @@ TEST_P(ValidateCFG, SingleLatchBlockMultipleBranchesToLoopHeader) { loop.SetBody("OpLoopMerge %merge %latch None\n"); } - string str = header(GetParam()) + nameOps("entry", "loop", "latch", "merge") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + nameOps("entry", "loop", "latch", "merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; - str += loop >> vector({latch, merge}); - str += latch >> vector({loop, loop}); // This is the key + str += loop >> std::vector({latch, merge}); + str += latch >> std::vector({loop, loop}); // This is the key str += merge; str += "OpFunctionEnd"; @@ -1322,8 +1381,9 @@ TEST_P(ValidateCFG, SingleLatchBlockHeaderContinueTargetIsItselfGood) { loop.SetBody("OpLoopMerge %merge %loop None\n"); } - string str = header(GetParam()) + nameOps("entry", "loop", "latch", "merge") + - types_consts() + "%func = OpFunction %voidt None %funct\n"; + std::string str = + header(GetParam()) + nameOps("entry", "loop", "latch", "merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; str += loop >> latch; @@ -1396,9 +1456,393 @@ TEST_F(ValidateCFG, OpReturnInNonVoidFunc) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "OpReturn can only be called from a function with void return type")); + "OpReturn can only be called from a function with void return type.\n" + " OpReturn")); +} + +TEST_F(ValidateCFG, StructuredCFGBranchIntoSelectionBody) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpSelectionMerge %merge None +OpBranchConditional %true %then %merge +%merge = OpLabel +OpBranch %then +%then = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("branches to the selection construct, but not to the " + "selection header 6\n %7 = OpLabel")); +} + +TEST_F(ValidateCFG, SwitchDefaultOnly) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +OpSelectionMerge %7 None +OpSwitch %3 %7 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, SwitchSingleCase) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +OpSelectionMerge %7 None +OpSwitch %3 %7 0 %8 +%8 = OpLabel +OpBranch %7 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, MultipleFallThroughBlocks) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranchConditional %6 %11 %12 +%11 = OpLabel +OpBranch %9 +%12 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Case construct that targets 10 has branches to multiple other case " + "construct targets 12 and 11\n %10 = OpLabel")); +} + +TEST_F(ValidateCFG, MultipleFallThroughToDefault) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %9 +%11 = OpLabel +OpBranch %10 +%12 = OpLabel +OpBranch %10 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Multiple case constructs have branches to the case construct " + "that targets 10\n %10 = OpLabel")); +} + +TEST_F(ValidateCFG, MultipleFallThroughToNonDefault) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %12 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Multiple case constructs have branches to the case construct " + "that targets 12\n %12 = OpLabel")); +} + +TEST_F(ValidateCFG, DuplicateTargetWithFallThrough) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %10 1 %11 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, WrongOperandList) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %9 +%12 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Case construct that targets 12 has branches to the case " + "construct that targets 11, but does not immediately " + "precede it in the OpSwitch's target list\n" + " OpSwitch %uint_0 %10 0 %11 1 %12")); +} + +TEST_F(ValidateCFG, WrongOperandListThroughDefault) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %11 +%12 = OpLabel +OpBranch %10 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Case construct that targets 12 has branches to the case " + "construct that targets 11, but does not immediately " + "precede it in the OpSwitch's target list\n" + " OpSwitch %uint_0 %10 0 %11 1 %12")); +} + +TEST_F(ValidateCFG, WrongOperandListNotLast) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 2 %13 +%10 = OpLabel +OpBranch %9 +%12 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %9 +%13 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Case construct that targets 12 has branches to the case " + "construct that targets 11, but does not immediately " + "precede it in the OpSwitch's target list\n" + " OpSwitch %uint_0 %10 0 %11 1 %12 2 %13")); +} + +TEST_F(ValidateCFG, InvalidCaseExit) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFunction %2 +%5 = OpConstant %3 0 +%1 = OpFunction %2 None %4 +%6 = OpLabel +OpSelectionMerge %7 None +OpSwitch %5 %7 0 %8 1 %9 +%8 = OpLabel +OpBranch %10 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +OpReturn +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Case construct that targets 8 has invalid branch to " + "block 10 (not another case construct, corresponding " + "merge, outer loop merge or outer loop continue")); +} + +TEST_F(ValidateCFG, GoodCaseExitsToOuterConstructs) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %7 %6 None +OpBranch %3 +%3 = OpLabel +OpSelectionMerge %5 None +OpSwitch %int0 %5 0 %4 +%4 = OpLabel +OpBranchConditional %true %6 %7 +%5 = OpLabel +OpBranchConditional %true %6 %7 +%6 = OpLabel +OpBranch %2 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -/// TODO(umar): Switch instructions /// TODO(umar): Nested CFG constructs + } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_composites_test.cpp b/3rdparty/spirv-tools/test/val/val_composites_test.cpp index 7ed8b172d..063626d1d 100644 --- a/3rdparty/spirv-tools/test/val/val_composites_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_composites_test.cpp @@ -16,9 +16,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -185,8 +187,7 @@ TEST_F(ValidateComposites, VectorExtractDynamicWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorExtractDynamic: " - "expected Result Type to be a scalar type")); + HasSubstr("Expected Result Type to be a scalar type")); } TEST_F(ValidateComposites, VectorExtractDynamicNotVector) { @@ -197,8 +198,7 @@ TEST_F(ValidateComposites, VectorExtractDynamicNotVector) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorExtractDynamic: " - "expected Vector type to be OpTypeVector")); + HasSubstr("Expected Vector type to be OpTypeVector")); } TEST_F(ValidateComposites, VectorExtractDynamicWrongVectorComponent) { @@ -210,8 +210,7 @@ TEST_F(ValidateComposites, VectorExtractDynamicWrongVectorComponent) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("VectorExtractDynamic: " - "expected Vector component type to be equal to Result Type")); + HasSubstr("Expected Vector component type to be equal to Result Type")); } TEST_F(ValidateComposites, VectorExtractDynamicWrongIndexType) { @@ -222,8 +221,7 @@ TEST_F(ValidateComposites, VectorExtractDynamicWrongIndexType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorExtractDynamic: " - "expected Index to be int scalar")); + HasSubstr("Expected Index to be int scalar")); } TEST_F(ValidateComposites, VectorInsertDynamicSuccess) { @@ -243,8 +241,7 @@ TEST_F(ValidateComposites, VectorInsertDynamicWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorInsertDynamic: " - "expected Result Type to be OpTypeVector")); + HasSubstr("Expected Result Type to be OpTypeVector")); } TEST_F(ValidateComposites, VectorInsertDynamicNotVector) { @@ -255,8 +252,7 @@ TEST_F(ValidateComposites, VectorInsertDynamicNotVector) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorInsertDynamic: " - "expected Vector type to be equal to Result Type")); + HasSubstr("Expected Vector type to be equal to Result Type")); } TEST_F(ValidateComposites, VectorInsertDynamicWrongComponentType) { @@ -267,8 +263,7 @@ TEST_F(ValidateComposites, VectorInsertDynamicWrongComponentType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorInsertDynamic: " - "expected Component type to be equal to Result Type " + HasSubstr("Expected Component type to be equal to Result Type " "component type")); } @@ -280,8 +275,7 @@ TEST_F(ValidateComposites, VectorInsertDynamicWrongIndexType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("VectorInsertDynamic: " - "expected Index to be int scalar")); + HasSubstr("Expected Index to be int scalar")); } TEST_F(ValidateComposites, CompositeConstructNotComposite) { @@ -292,8 +286,7 @@ TEST_F(ValidateComposites, CompositeConstructNotComposite) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Result Type to be a composite type")); + HasSubstr("Expected Result Type to be a composite type")); } TEST_F(ValidateComposites, CompositeConstructVectorSuccess) { @@ -316,8 +309,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorOnlyOneConstituent) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected number of constituents to be at least 2")); + HasSubstr("Expected number of constituents to be at least 2")); } TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent1) { @@ -329,8 +321,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent1) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Constituents to be scalars or vectors of the same " + HasSubstr("Expected Constituents to be scalars or vectors of the same " "type as Result Type components")); } @@ -343,8 +334,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent2) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Constituents to be scalars or vectors of the same " + HasSubstr("Expected Constituents to be scalars or vectors of the same " "type as Result Type components")); } @@ -357,8 +347,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent3) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Constituents to be scalars or vectors of the same " + HasSubstr("Expected Constituents to be scalars or vectors of the same " "type as Result Type components")); } @@ -371,8 +360,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongComponentNumber1) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of given components to be equal to the " + HasSubstr("Expected total number of given components to be equal to the " "size of Result Type vector")); } @@ -385,8 +373,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongComponentNumber2) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of given components to be equal to the " + HasSubstr("Expected total number of given components to be equal to the " "size of Result Type vector")); } @@ -409,8 +396,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituentNumber1) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of Constituents to be equal to the " + HasSubstr("Expected total number of Constituents to be equal to the " "number of columns of Result Type matrix")); } @@ -423,8 +409,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituentNumber2) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of Constituents to be equal to the " + HasSubstr("Expected total number of Constituents to be equal to the " "number of columns of Result Type matrix")); } @@ -437,8 +422,7 @@ TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Constituent type to be equal to the column type " + HasSubstr("Expected Constituent type to be equal to the column type " "Result Type matrix")); } @@ -460,8 +444,7 @@ TEST_F(ValidateComposites, CompositeConstructArrayWrongConsituentNumber1) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of Constituents to be equal to the " + HasSubstr("Expected total number of Constituents to be equal to the " "number of elements of Result Type array")); } @@ -474,8 +457,7 @@ TEST_F(ValidateComposites, CompositeConstructArrayWrongConsituentNumber2) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of Constituents to be equal to the " + HasSubstr("Expected total number of Constituents to be equal to the " "number of elements of Result Type array")); } @@ -488,8 +470,7 @@ TEST_F(ValidateComposites, CompositeConstructArrayWrongConsituent) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Constituent type to be equal to the column type " + HasSubstr("Expected Constituent type to be equal to the column type " "Result Type array")); } @@ -511,8 +492,7 @@ TEST_F(ValidateComposites, CompositeConstructStructWrongConstituentNumber1) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of Constituents to be equal to the " + HasSubstr("Expected total number of Constituents to be equal to the " "number of members of Result Type struct")); } @@ -525,8 +505,7 @@ TEST_F(ValidateComposites, CompositeConstructStructWrongConstituentNumber2) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected total number of Constituents to be equal to the " + HasSubstr("Expected total number of Constituents to be equal to the " "number of members of Result Type struct")); } @@ -538,8 +517,7 @@ TEST_F(ValidateComposites, CompositeConstructStructWrongConstituent) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeConstruct: " - "expected Constituent type to be equal to the " + HasSubstr("Expected Constituent type to be equal to the " "corresponding member type of Result Type struct")); } @@ -559,9 +537,8 @@ TEST_F(ValidateComposites, CopyObjectResultTypeNotType) { )"; CompileSuccessfully(GenerateShaderCode(body).c_str()); - ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("CopyObject: expected Result Type to be a type")); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 19 is not a type id")); } TEST_F(ValidateComposites, CopyObjectWrongOperandType) { @@ -573,8 +550,7 @@ TEST_F(ValidateComposites, CopyObjectWrongOperandType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("CopyObject: " - "expected Result Type and Operand type to be the same")); + HasSubstr("Expected Result Type and Operand type to be the same")); } TEST_F(ValidateComposites, TransposeSuccess) { @@ -595,7 +571,7 @@ TEST_F(ValidateComposites, TransposeResultTypeNotMatrix) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Transpose: expected Result Type to be a matrix type")); + HasSubstr("Expected Result Type to be a matrix type")); } TEST_F(ValidateComposites, TransposeDifferentComponentTypes) { @@ -607,8 +583,7 @@ TEST_F(ValidateComposites, TransposeDifferentComponentTypes) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Transpose: " - "expected component types of Matrix and Result Type to be " + HasSubstr("Expected component types of Matrix and Result Type to be " "identical")); } @@ -619,10 +594,9 @@ TEST_F(ValidateComposites, TransposeIncompatibleDimensions1) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Transpose: expected number of columns and the column size " - "of Matrix to be the reverse of those of Result Type")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns and the column size " + "of Matrix to be the reverse of those of Result Type")); } TEST_F(ValidateComposites, TransposeIncompatibleDimensions2) { @@ -632,10 +606,9 @@ TEST_F(ValidateComposites, TransposeIncompatibleDimensions2) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Transpose: expected number of columns and the column size " - "of Matrix to be the reverse of those of Result Type")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns and the column size " + "of Matrix to be the reverse of those of Result Type")); } TEST_F(ValidateComposites, TransposeIncompatibleDimensions3) { @@ -645,10 +618,9 @@ TEST_F(ValidateComposites, TransposeIncompatibleDimensions3) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Transpose: expected number of columns and the column size " - "of Matrix to be the reverse of those of Result Type")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns and the column size " + "of Matrix to be the reverse of those of Result Type")); } TEST_F(ValidateComposites, CompositeExtractSuccess) { @@ -688,7 +660,7 @@ TEST_F(ValidateComposites, CompositeExtractNotObject) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeExtract: expected Composite to be an object " + HasSubstr("Expected Composite to be an object " "of composite type")); } @@ -700,8 +672,8 @@ TEST_F(ValidateComposites, CompositeExtractNotComposite) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract reached non-composite type while " - "indexes still remain to be traversed.")); + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); } TEST_F(ValidateComposites, CompositeExtractVectorOutOfBounds) { @@ -712,7 +684,7 @@ TEST_F(ValidateComposites, CompositeExtractVectorOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeExtract: vector access is out of bounds, " + HasSubstr("Vector access is out of bounds, " "vector size is 4, but access index is 4")); } @@ -724,7 +696,7 @@ TEST_F(ValidateComposites, CompositeExtractMatrixOutOfCols) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeExtract: matrix access is out of bounds, " + HasSubstr("Matrix access is out of bounds, " "matrix has 3 columns, but access index is 3")); } @@ -736,7 +708,7 @@ TEST_F(ValidateComposites, CompositeExtractMatrixOutOfRows) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeExtract: vector access is out of bounds, " + HasSubstr("Vector access is out of bounds, " "vector size is 2, but access index is 5")); } @@ -749,7 +721,7 @@ TEST_F(ValidateComposites, CompositeExtractArrayOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeExtract: array access is out of bounds, " + HasSubstr("Array access is out of bounds, " "array size is 3, but access index is 3")); } @@ -762,9 +734,9 @@ TEST_F(ValidateComposites, CompositeExtractStructOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Index is out of bounds: OpCompositeExtract can not " - "find index 6 into the structure '37'. This " - "structure has 6 members. Largest valid index is 5.")); + HasSubstr("Index is out of bounds, can not find index 6 in the " + "structure '37'. This structure has 6 members. " + "Largest valid index is 5.")); } TEST_F(ValidateComposites, CompositeExtractNestedVectorOutOfBounds) { @@ -776,7 +748,7 @@ TEST_F(ValidateComposites, CompositeExtractNestedVectorOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeExtract: vector access is out of bounds, " + HasSubstr("Vector access is out of bounds, " "vector size is 2, but access index is 5")); } @@ -789,7 +761,7 @@ TEST_F(ValidateComposites, CompositeExtractTooManyIndices) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract reached non-composite type while " + HasSubstr("Reached non-composite type while " "indexes still remain to be traversed.")); } @@ -804,8 +776,8 @@ TEST_F(ValidateComposites, CompositeExtractWrongType1) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "OpCompositeExtract result type (OpTypeVector) does not match the " - "type that results from indexing into the composite (OpTypeFloat).")); + "Result type (OpTypeVector) does not match the type that results " + "from indexing into the composite (OpTypeFloat).")); } TEST_F(ValidateComposites, CompositeExtractWrongType2) { @@ -817,9 +789,9 @@ TEST_F(ValidateComposites, CompositeExtractWrongType2) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract result type (OpTypeFloat) does not " - "match the type that results from indexing into the " - "composite (OpTypeVector).")); + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeVector).")); } TEST_F(ValidateComposites, CompositeExtractWrongType3) { @@ -831,9 +803,9 @@ TEST_F(ValidateComposites, CompositeExtractWrongType3) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract result type (OpTypeFloat) does not " - "match the type that results from indexing into the " - "composite (OpTypeVector).")); + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeVector).")); } TEST_F(ValidateComposites, CompositeExtractWrongType4) { @@ -845,9 +817,9 @@ TEST_F(ValidateComposites, CompositeExtractWrongType4) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract result type (OpTypeFloat) does not " - "match the type that results from indexing into the " - "composite (OpTypeVector).")); + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeVector).")); } TEST_F(ValidateComposites, CompositeExtractWrongType5) { @@ -861,7 +833,7 @@ TEST_F(ValidateComposites, CompositeExtractWrongType5) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "OpCompositeExtract result type (OpTypeFloat) does not match the " + "Result type (OpTypeFloat) does not match the " "type that results from indexing into the composite (OpTypeInt).")); } @@ -914,8 +886,8 @@ TEST_F(ValidateComposites, CompositeInsertNotComposite) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeInsert reached non-composite type while " - "indexes still remain to be traversed.")); + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); } TEST_F(ValidateComposites, CompositeInsertVectorOutOfBounds) { @@ -926,7 +898,7 @@ TEST_F(ValidateComposites, CompositeInsertVectorOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeInsert: vector access is out of bounds, " + HasSubstr("Vector access is out of bounds, " "vector size is 4, but access index is 4")); } @@ -938,7 +910,7 @@ TEST_F(ValidateComposites, CompositeInsertMatrixOutOfCols) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeInsert: matrix access is out of bounds, " + HasSubstr("Matrix access is out of bounds, " "matrix has 3 columns, but access index is 3")); } @@ -950,7 +922,7 @@ TEST_F(ValidateComposites, CompositeInsertMatrixOutOfRows) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeInsert: vector access is out of bounds, " + HasSubstr("Vector access is out of bounds, " "vector size is 2, but access index is 5")); } @@ -963,7 +935,7 @@ TEST_F(ValidateComposites, CompositeInsertArrayOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeInsert: array access is out of bounds, array " + HasSubstr("Array access is out of bounds, array " "size is 3, but access index is 3")); } @@ -976,9 +948,9 @@ TEST_F(ValidateComposites, CompositeInsertStructOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Index is out of bounds: OpCompositeInsert can not " - "find index 6 into the structure '37'. This " - "structure has 6 members. Largest valid index is 5.")); + HasSubstr("Index is out of bounds, can not find index 6 in the " + "structure '37'. This structure has 6 members. " + "Largest valid index is 5.")); } TEST_F(ValidateComposites, CompositeInsertNestedVectorOutOfBounds) { @@ -990,7 +962,7 @@ TEST_F(ValidateComposites, CompositeInsertNestedVectorOutOfBounds) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("CompositeInsert: vector access is out of bounds, " + HasSubstr("Vector access is out of bounds, " "vector size is 2, but access index is 5")); } @@ -1003,8 +975,8 @@ TEST_F(ValidateComposites, CompositeInsertTooManyIndices) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeInsert reached non-composite type while " - "indexes still remain to be traversed.")); + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); } TEST_F(ValidateComposites, CompositeInsertWrongType1) { @@ -1016,9 +988,9 @@ TEST_F(ValidateComposites, CompositeInsertWrongType1) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("The Object type (OpTypeVector) in OpCompositeInsert " - "does not match the type that results from indexing " - "into the Composite (OpTypeFloat).")); + HasSubstr("The Object type (OpTypeVector) does not match the " + "type that results from indexing into the Composite " + "(OpTypeFloat).")); } TEST_F(ValidateComposites, CompositeInsertWrongType2) { @@ -1030,9 +1002,9 @@ TEST_F(ValidateComposites, CompositeInsertWrongType2) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("The Object type (OpTypeFloat) in OpCompositeInsert " - "does not match the type that results from indexing " - "into the Composite (OpTypeVector).")); + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeVector).")); } TEST_F(ValidateComposites, CompositeInsertWrongType3) { @@ -1044,9 +1016,9 @@ TEST_F(ValidateComposites, CompositeInsertWrongType3) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("The Object type (OpTypeFloat) in OpCompositeInsert " - "does not match the type that results from indexing " - "into the Composite (OpTypeVector).")); + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeVector).")); } TEST_F(ValidateComposites, CompositeInsertWrongType4) { @@ -1058,9 +1030,9 @@ TEST_F(ValidateComposites, CompositeInsertWrongType4) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("The Object type (OpTypeFloat) in OpCompositeInsert " - "does not match the type that results from indexing " - "into the Composite (OpTypeVector).")); + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeVector).")); } TEST_F(ValidateComposites, CompositeInsertWrongType5) { @@ -1072,9 +1044,9 @@ TEST_F(ValidateComposites, CompositeInsertWrongType5) { CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("The Object type (OpTypeFloat) in OpCompositeInsert " - "does not match the type that results from indexing " - "into the Composite (OpTypeInt).")); + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeInt).")); } // Tests ported from val_id_test.cpp. @@ -1208,9 +1180,9 @@ TEST_F(ValidateComposites, CompositeExtractNoIndexesBad) { CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract result type (OpTypeFloat) does not " - "match the type that results from indexing into the " - "composite (OpTypeMatrix).")); + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeMatrix).")); } // Valid: No Indexes were passed to OpCompositeInsert, and the type of the @@ -1241,9 +1213,9 @@ TEST_F(ValidateComposites, CompositeInsertMissingIndexesBad) { CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("The Object type (OpTypeInt) in OpCompositeInsert does " - "not match the type that results from indexing into " - "the Composite (OpTypeMatrix).")); + HasSubstr("The Object type (OpTypeInt) does not match the type " + "that results from indexing into the Composite " + "(OpTypeMatrix).")); } // Valid: Tests that we can index into Struct, Array, Matrix, and Vector! @@ -1297,8 +1269,8 @@ TEST_F(ValidateComposites, CompositeExtractReachedScalarBad) { CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract reached non-composite type while " - "indexes still remain to be traversed.")); + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); } // Invalid. More indexes are provided than needed for OpCompositeInsert. @@ -1321,8 +1293,8 @@ TEST_F(ValidateComposites, CompositeInsertReachedScalarBad) { CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeInsert reached non-composite type while " - "indexes still remain to be traversed.")); + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); } // Invalid. Result type doesn't match the type we get from indexing into @@ -1346,9 +1318,9 @@ TEST_F(ValidateComposites, CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCompositeExtract result type (OpTypeInt) does not " - "match the type that results from indexing into the " - "composite (OpTypeFloat).")); + HasSubstr("Result type (OpTypeInt) does not match the type that " + "results from indexing into the composite " + "(OpTypeFloat).")); } // Invalid. Given object type doesn't match the type we get from indexing into @@ -1372,9 +1344,9 @@ TEST_F(ValidateComposites, CompositeInsertObjectTypeDoesntMatchIndexedTypeBad) { CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("he Object type (OpTypeInt) in OpCompositeInsert does " - "not match the type that results from indexing into " - "the Composite (OpTypeFloat).")); + HasSubstr("The Object type (OpTypeInt) does not match the type " + "that results from indexing into the Composite " + "(OpTypeFloat).")); } // Invalid. Index into a struct is larger than the number of struct members. @@ -1391,9 +1363,9 @@ TEST_F(ValidateComposites, CompositeExtractStructIndexOutOfBoundBad) { CompileSuccessfully(spirv.str()); EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Index is out of bounds: OpCompositeExtract can not " - "find index 3 into the structure '26'. This " - "structure has 3 members. Largest valid index is 2.")); + HasSubstr("Index is out of bounds, can not find index 3 in the " + "structure '26'. This structure has 3 members. " + "Largest valid index is 2.")); } // Invalid. Index into a struct is larger than the number of struct members. @@ -1412,9 +1384,9 @@ TEST_F(ValidateComposites, CompositeInsertStructIndexOutOfBoundBad) { EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Index is out of bounds: OpCompositeInsert can not find " - "index 3 into the structure '26'. This structure " - "has 3 members. Largest valid index is 2.")); + HasSubstr("Index is out of bounds, can not find index 3 in the structure " + " '26'. This structure has 3 members. Largest valid index " + "is 2.")); } // #1403: Ensure that the default spec constant value is not used to check the @@ -1494,4 +1466,7 @@ OpFunctionEnd CompileSuccessfully(spirv); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } -} // anonymous namespace + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_conversion_test.cpp b/3rdparty/spirv-tools/test/val/val_conversion_test.cpp index 66a3e68f7..e0b8a0018 100644 --- a/3rdparty/spirv-tools/test/val/val_conversion_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_conversion_test.cpp @@ -17,9 +17,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -1103,4 +1105,6 @@ TEST_F(ValidateConversion, BitcastDifferentTotalBitWidth) { "Bitcast")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_data_test.cpp b/3rdparty/spirv-tools/test/val/val_data_test.cpp index afa6afb6d..d022d8b8a 100644 --- a/3rdparty/spirv-tools/test/val/val_data_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_data_test.cpp @@ -19,95 +19,93 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; using ::testing::MatchesRegex; -using std::pair; -using std::string; -using std::stringstream; +using ValidateData = spvtest::ValidateBase>; -using ValidateData = spvtest::ValidateBase>; - -string HeaderWith(std::string cap) { +std::string HeaderWith(std::string cap) { return std::string("OpCapability Shader OpCapability Linkage OpCapability ") + cap + " OpMemoryModel Logical GLSL450 "; } -string header = R"( +std::string header = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 )"; -string header_with_addresses = R"( +std::string header_with_addresses = R"( OpCapability Addresses OpCapability Kernel OpCapability GenericPointer OpCapability Linkage OpMemoryModel Physical32 OpenCL )"; -string header_with_vec16_cap = R"( +std::string header_with_vec16_cap = R"( OpCapability Shader OpCapability Vector16 OpCapability Linkage OpMemoryModel Logical GLSL450 )"; -string header_with_int8 = R"( +std::string header_with_int8 = R"( OpCapability Shader OpCapability Linkage OpCapability Int8 OpMemoryModel Logical GLSL450 )"; -string header_with_int16 = R"( +std::string header_with_int16 = R"( OpCapability Shader OpCapability Linkage OpCapability Int16 OpMemoryModel Logical GLSL450 )"; -string header_with_int64 = R"( +std::string header_with_int64 = R"( OpCapability Shader OpCapability Linkage OpCapability Int64 OpMemoryModel Logical GLSL450 )"; -string header_with_float16 = R"( +std::string header_with_float16 = R"( OpCapability Shader OpCapability Linkage OpCapability Float16 OpMemoryModel Logical GLSL450 )"; -string header_with_float16_buffer = R"( +std::string header_with_float16_buffer = R"( OpCapability Shader OpCapability Linkage OpCapability Float16Buffer OpMemoryModel Logical GLSL450 )"; -string header_with_float64 = R"( +std::string header_with_float64 = R"( OpCapability Shader OpCapability Linkage OpCapability Float64 OpMemoryModel Logical GLSL450 )"; -string invalid_comp_error = "Illegal number of components"; -string missing_cap_error = "requires the Vector16 capability"; -string missing_int8_cap_error = "requires the Int8 capability"; -string missing_int16_cap_error = +std::string invalid_comp_error = "Illegal number of components"; +std::string missing_cap_error = "requires the Vector16 capability"; +std::string missing_int8_cap_error = "requires the Int8 capability"; +std::string missing_int16_cap_error = "requires the Int16 capability," " or an extension that explicitly enables 16-bit integers."; -string missing_int64_cap_error = "requires the Int64 capability"; -string missing_float16_cap_error = +std::string missing_int64_cap_error = "requires the Int64 capability"; +std::string missing_float16_cap_error = "requires the Float16 or Float16Buffer capability," " or an extension that explicitly enables 16-bit floating point."; -string missing_float64_cap_error = "requires the Float64 capability"; -string invalid_num_bits_error = "Invalid number of bits"; +std::string missing_float64_cap_error = "requires the Float64 capability"; +std::string invalid_num_bits_error = "Invalid number of bits"; TEST_F(ValidateData, vec0) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 0 )"; @@ -117,7 +115,7 @@ TEST_F(ValidateData, vec0) { } TEST_F(ValidateData, vec1) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 1 )"; @@ -127,7 +125,7 @@ TEST_F(ValidateData, vec1) { } TEST_F(ValidateData, vec2) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 2 )"; @@ -136,7 +134,7 @@ TEST_F(ValidateData, vec2) { } TEST_F(ValidateData, vec3) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 3 )"; @@ -145,7 +143,7 @@ TEST_F(ValidateData, vec3) { } TEST_F(ValidateData, vec4) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 )"; @@ -154,7 +152,7 @@ TEST_F(ValidateData, vec4) { } TEST_F(ValidateData, vec5) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 5 )"; @@ -164,7 +162,7 @@ TEST_F(ValidateData, vec5) { } TEST_F(ValidateData, vec8) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 8 )"; @@ -174,7 +172,7 @@ TEST_F(ValidateData, vec8) { } TEST_F(ValidateData, vec8_with_capability) { - string str = header_with_vec16_cap + R"( + std::string str = header_with_vec16_cap + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 8 )"; @@ -183,7 +181,7 @@ TEST_F(ValidateData, vec8_with_capability) { } TEST_F(ValidateData, vec16) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 8 )"; @@ -193,7 +191,7 @@ TEST_F(ValidateData, vec16) { } TEST_F(ValidateData, vec16_with_capability) { - string str = header_with_vec16_cap + R"( + std::string str = header_with_vec16_cap + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 16 )"; @@ -202,7 +200,7 @@ TEST_F(ValidateData, vec16_with_capability) { } TEST_F(ValidateData, vec15) { - string str = header + R"( + std::string str = header + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 15 )"; @@ -212,35 +210,62 @@ TEST_F(ValidateData, vec15) { } TEST_F(ValidateData, int8_good) { - string str = header_with_int8 + "%2 = OpTypeInt 8 0"; + std::string str = header_with_int8 + "%2 = OpTypeInt 8 0"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, int8_bad) { - string str = header + "%2 = OpTypeInt 8 1"; + std::string str = header + "%2 = OpTypeInt 8 1"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_int8_cap_error)); } +TEST_F(ValidateData, int8_with_storage_buffer_8bit_access_good) { + std::string str = HeaderWith( + "StorageBuffer8BitAccess " + "OpExtension \"SPV_KHR_8bit_storage\"") + + " %2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateData, int8_with_uniform_and_storage_buffer_8bit_access_good) { + std::string str = HeaderWith( + "UniformAndStorageBuffer8BitAccess " + "OpExtension \"SPV_KHR_8bit_storage\"") + + " %2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateData, int8_with_storage_push_constant_8_good) { + std::string str = HeaderWith( + "StoragePushConstant8 " + "OpExtension \"SPV_KHR_8bit_storage\"") + + " %2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + TEST_F(ValidateData, int16_good) { - string str = header_with_int16 + "%2 = OpTypeInt 16 1"; + std::string str = header_with_int16 + "%2 = OpTypeInt 16 1"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, storage_uniform_buffer_block_16_good) { - string str = HeaderWith( - "StorageUniformBufferBlock16 " - "OpExtension \"SPV_KHR_16bit_storage\"") + - "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + std::string str = HeaderWith( + "StorageUniformBufferBlock16 " + "OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, storage_uniform_16_good) { - string str = + std::string str = HeaderWith("StorageUniform16 OpExtension \"SPV_KHR_16bit_storage\"") + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); @@ -248,38 +273,38 @@ TEST_F(ValidateData, storage_uniform_16_good) { } TEST_F(ValidateData, storage_push_constant_16_good) { - string str = HeaderWith( - "StoragePushConstant16 " - "OpExtension \"SPV_KHR_16bit_storage\"") + - "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + std::string str = HeaderWith( + "StoragePushConstant16 " + "OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, storage_input_output_16_good) { - string str = HeaderWith( - "StorageInputOutput16 " - "OpExtension \"SPV_KHR_16bit_storage\"") + - "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + std::string str = HeaderWith( + "StorageInputOutput16 " + "OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, int16_bad) { - string str = header + "%2 = OpTypeInt 16 1"; + std::string str = header + "%2 = OpTypeInt 16 1"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_int16_cap_error)); } TEST_F(ValidateData, int64_good) { - string str = header_with_int64 + "%2 = OpTypeInt 64 1"; + std::string str = header_with_int64 + "%2 = OpTypeInt 64 1"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, int64_bad) { - string str = header + "%2 = OpTypeInt 64 1"; + std::string str = header + "%2 = OpTypeInt 64 1"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_int64_cap_error)); @@ -287,39 +312,39 @@ TEST_F(ValidateData, int64_bad) { // Number of bits in an integer may be only one of: {8,16,32,64} TEST_F(ValidateData, int_invalid_num_bits) { - string str = header + "%2 = OpTypeInt 48 1"; + std::string str = header + "%2 = OpTypeInt 48 1"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_num_bits_error)); } TEST_F(ValidateData, float16_good) { - string str = header_with_float16 + "%2 = OpTypeFloat 16"; + std::string str = header_with_float16 + "%2 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, float16_buffer_good) { - string str = header_with_float16_buffer + "%2 = OpTypeFloat 16"; + std::string str = header_with_float16_buffer + "%2 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, float16_bad) { - string str = header + "%2 = OpTypeFloat 16"; + std::string str = header + "%2 = OpTypeFloat 16"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float16_cap_error)); } TEST_F(ValidateData, float64_good) { - string str = header_with_float64 + "%2 = OpTypeFloat 64"; + std::string str = header_with_float64 + "%2 = OpTypeFloat 64"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateData, float64_bad) { - string str = header + "%2 = OpTypeFloat 64"; + std::string str = header + "%2 = OpTypeFloat 64"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float64_cap_error)); @@ -327,14 +352,14 @@ TEST_F(ValidateData, float64_bad) { // Number of bits in a float may be only one of: {16,32,64} TEST_F(ValidateData, float_invalid_num_bits) { - string str = header + "%2 = OpTypeFloat 48"; + std::string str = header + "%2 = OpTypeFloat 48"; CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_num_bits_error)); } TEST_F(ValidateData, matrix_data_type_float) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 3 @@ -343,8 +368,18 @@ TEST_F(ValidateData, matrix_data_type_float) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateData, ids_should_be_validated_before_data) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%mat33 = OpTypeMatrix %vec3 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 3 has not been defined")); +} + TEST_F(ValidateData, matrix_bad_column_type) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %mat33 = OpTypeMatrix %f32 3 )"; @@ -355,7 +390,7 @@ TEST_F(ValidateData, matrix_bad_column_type) { } TEST_F(ValidateData, matrix_data_type_int) { - string str = header + R"( + std::string str = header + R"( %int32 = OpTypeInt 32 1 %vec3 = OpTypeVector %int32 3 %mat33 = OpTypeMatrix %vec3 3 @@ -367,7 +402,7 @@ TEST_F(ValidateData, matrix_data_type_int) { } TEST_F(ValidateData, matrix_data_type_bool) { - string str = header + R"( + std::string str = header + R"( %boolt = OpTypeBool %vec3 = OpTypeVector %boolt 3 %mat33 = OpTypeMatrix %vec3 3 @@ -379,7 +414,7 @@ TEST_F(ValidateData, matrix_data_type_bool) { } TEST_F(ValidateData, matrix_with_0_columns) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 0 @@ -392,7 +427,7 @@ TEST_F(ValidateData, matrix_with_0_columns) { } TEST_F(ValidateData, matrix_with_1_column) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 1 @@ -405,7 +440,7 @@ TEST_F(ValidateData, matrix_with_1_column) { } TEST_F(ValidateData, matrix_with_2_columns) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 2 @@ -415,7 +450,7 @@ TEST_F(ValidateData, matrix_with_2_columns) { } TEST_F(ValidateData, matrix_with_3_columns) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 3 @@ -425,7 +460,7 @@ TEST_F(ValidateData, matrix_with_3_columns) { } TEST_F(ValidateData, matrix_with_4_columns) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 4 @@ -435,7 +470,7 @@ TEST_F(ValidateData, matrix_with_4_columns) { } TEST_F(ValidateData, matrix_with_5_column) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %vec3 = OpTypeVector %f32 3 %mat33 = OpTypeMatrix %vec3 5 @@ -448,7 +483,7 @@ TEST_F(ValidateData, matrix_with_5_column) { } TEST_F(ValidateData, specialize_int) { - string str = header + R"( + std::string str = header + R"( %i32 = OpTypeInt 32 1 %len = OpSpecConstant %i32 2)"; CompileSuccessfully(str.c_str()); @@ -456,7 +491,7 @@ TEST_F(ValidateData, specialize_int) { } TEST_F(ValidateData, specialize_float) { - string str = header + R"( + std::string str = header + R"( %f32 = OpTypeFloat 32 %len = OpSpecConstant %f32 2)"; CompileSuccessfully(str.c_str()); @@ -464,7 +499,7 @@ TEST_F(ValidateData, specialize_float) { } TEST_F(ValidateData, specialize_boolean) { - string str = header + R"( + std::string str = header + R"( %2 = OpTypeBool %3 = OpSpecConstantTrue %2 %4 = OpSpecConstantFalse %2)"; @@ -473,7 +508,7 @@ TEST_F(ValidateData, specialize_boolean) { } TEST_F(ValidateData, specialize_boolean_to_int) { - string str = header + R"( + std::string str = header + R"( %2 = OpTypeInt 32 1 %3 = OpSpecConstantTrue %2 %4 = OpSpecConstantFalse %2)"; @@ -484,7 +519,7 @@ TEST_F(ValidateData, specialize_boolean_to_int) { } TEST_F(ValidateData, missing_forward_pointer_decl) { - string str = header_with_addresses + R"( + std::string str = header_with_addresses + R"( %uintt = OpTypeInt 32 0 %3 = OpTypeStruct %fwd_ptrt %uintt )"; @@ -494,8 +529,19 @@ TEST_F(ValidateData, missing_forward_pointer_decl) { HasSubstr("must first be declared using OpTypeForwardPointer")); } +TEST_F(ValidateData, missing_forward_pointer_decl_self_reference) { + std::string str = header_with_addresses + R"( +%uintt = OpTypeInt 32 0 +%3 = OpTypeStruct %3 %uintt +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must first be declared using OpTypeForwardPointer")); +} + TEST_F(ValidateData, forward_pointer_missing_definition) { - string str = header_with_addresses + R"( + std::string str = header_with_addresses + R"( OpTypeForwardPointer %_ptr_Generic_struct_A Generic %uintt = OpTypeInt 32 0 %struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A @@ -507,7 +553,7 @@ OpTypeForwardPointer %_ptr_Generic_struct_A Generic } TEST_F(ValidateData, forward_ref_bad_type) { - string str = header_with_addresses + R"( + std::string str = header_with_addresses + R"( OpTypeForwardPointer %_ptr_Generic_struct_A Generic %uintt = OpTypeInt 32 0 %struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A @@ -521,7 +567,7 @@ OpTypeForwardPointer %_ptr_Generic_struct_A Generic } TEST_F(ValidateData, forward_ref_points_to_non_struct) { - string str = header_with_addresses + R"( + std::string str = header_with_addresses + R"( OpTypeForwardPointer %_ptr_Generic_struct_A Generic %uintt = OpTypeInt 32 0 %struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A @@ -536,7 +582,7 @@ OpTypeForwardPointer %_ptr_Generic_struct_A Generic } TEST_F(ValidateData, struct_forward_pointer_good) { - string str = header_with_addresses + R"( + std::string str = header_with_addresses + R"( OpTypeForwardPointer %_ptr_Generic_struct_A Generic %uintt = OpTypeInt 32 0 %struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A @@ -552,15 +598,14 @@ TEST_F(ValidateData, ext_16bit_storage_caps_allow_free_fp_rounding_mode) { for (const char* cap : {"StorageUniform16", "StorageUniformBufferBlock16", "StoragePushConstant16", "StorageInputOutput16"}) { for (const char* mode : {"RTE", "RTZ", "RTP", "RTN"}) { - string str = string(R"( + std::string str = std::string(R"( OpCapability Shader OpCapability Linkage OpCapability )") + - cap + R"( + cap + R"( OpExtension "SPV_KHR_16bit_storage" OpMemoryModel Logical GLSL450 - OpDecorate %2 FPRoundingMode )" + - mode + R"( + OpDecorate %2 FPRoundingMode )" + mode + R"( %1 = OpTypeFloat 32 %2 = OpConstant %1 1.25 )"; @@ -573,11 +618,11 @@ TEST_F(ValidateData, ext_16bit_storage_caps_allow_free_fp_rounding_mode) { TEST_F(ValidateData, vulkan_disallow_free_fp_rounding_mode) { for (const char* mode : {"RTE", "RTZ"}) { for (const auto env : {SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1}) { - string str = string(R"( + std::string str = std::string(R"( OpCapability Shader OpMemoryModel Logical GLSL450 OpDecorate %2 FPRoundingMode )") + - mode + R"( + mode + R"( %1 = OpTypeFloat 32 %2 = OpConstant %1 1.25 )"; @@ -592,4 +637,6 @@ TEST_F(ValidateData, vulkan_disallow_free_fp_rounding_mode) { } } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_decoration_test.cpp b/3rdparty/spirv-tools/test/val/val_decoration_test.cpp index 2c0960c0b..c968183ca 100644 --- a/3rdparty/spirv-tools/test/val/val_decoration_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_decoration_test.cpp @@ -12,25 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Common validation fixtures for unit tests +// Validation tests for decorations + +#include +#include #include "gmock/gmock.h" #include "source/val/decoration.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::Eq; using ::testing::HasSubstr; -using libspirv::Decoration; -using std::string; -using std::vector; using ValidateDecorations = spvtest::ValidateBase; TEST_F(ValidateDecorations, ValidateOpDecorateRegistration) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -44,13 +46,14 @@ TEST_F(ValidateDecorations, ValidateOpDecorateRegistration) { CompileSuccessfully(spirv); EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); // Must have 2 decorations. - EXPECT_THAT(vstate_->id_decorations(id), - Eq(vector{Decoration(SpvDecorationArrayStride, {4}), - Decoration(SpvDecorationUniform)})); + EXPECT_THAT( + vstate_->id_decorations(id), + Eq(std::vector{Decoration(SpvDecorationArrayStride, {4}), + Decoration(SpvDecorationUniform)})); } TEST_F(ValidateDecorations, ValidateOpMemberDecorateRegistration) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -71,18 +74,19 @@ TEST_F(ValidateDecorations, ValidateOpMemberDecorateRegistration) { const uint32_t arr_id = 1; EXPECT_THAT( vstate_->id_decorations(arr_id), - Eq(vector{Decoration(SpvDecorationArrayStride, {4})})); + Eq(std::vector{Decoration(SpvDecorationArrayStride, {4})})); // The struct must have 3 decorations. const uint32_t struct_id = 2; - EXPECT_THAT(vstate_->id_decorations(struct_id), - Eq(vector{Decoration(SpvDecorationNonReadable, {}, 2), - Decoration(SpvDecorationOffset, {2}, 2), - Decoration(SpvDecorationBufferBlock)})); + EXPECT_THAT( + vstate_->id_decorations(struct_id), + Eq(std::vector{Decoration(SpvDecorationNonReadable, {}, 2), + Decoration(SpvDecorationOffset, {2}, 2), + Decoration(SpvDecorationBufferBlock)})); } TEST_F(ValidateDecorations, ValidateGroupDecorateRegistration) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -108,7 +112,7 @@ TEST_F(ValidateDecorations, ValidateGroupDecorateRegistration) { EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); // Decoration group has 3 decorations. - auto expected_decorations = vector{ + auto expected_decorations = std::vector{ Decoration(SpvDecorationDescriptorSet, {0}), Decoration(SpvDecorationNonWritable), Decoration(SpvDecorationRestrict)}; @@ -121,7 +125,7 @@ TEST_F(ValidateDecorations, ValidateGroupDecorateRegistration) { } TEST_F(ValidateDecorations, ValidateGroupMemberDecorateRegistration) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -138,7 +142,7 @@ TEST_F(ValidateDecorations, ValidateGroupMemberDecorateRegistration) { EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); // Decoration group has 1 decoration. auto expected_decorations = - vector{Decoration(SpvDecorationOffset, {3}, 3)}; + std::vector{Decoration(SpvDecorationOffset, {3}, 3)}; // Decoration group is applied to id 2, 3, and 4. EXPECT_THAT(vstate_->id_decorations(2), Eq(expected_decorations)); @@ -147,7 +151,7 @@ TEST_F(ValidateDecorations, ValidateGroupMemberDecorateRegistration) { } TEST_F(ValidateDecorations, LinkageImportUsedForInitializedVariableBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -164,7 +168,7 @@ TEST_F(ValidateDecorations, LinkageImportUsedForInitializedVariableBad) { "cannot be marked with the Import Linkage Type.")); } TEST_F(ValidateDecorations, LinkageExportUsedForInitializedVariableGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -179,7 +183,7 @@ TEST_F(ValidateDecorations, LinkageExportUsedForInitializedVariableGood) { } TEST_F(ValidateDecorations, StructAllMembersHaveBuiltInDecorationsGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -196,7 +200,7 @@ TEST_F(ValidateDecorations, StructAllMembersHaveBuiltInDecorationsGood) { } TEST_F(ValidateDecorations, MixedBuiltInDecorationsBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -218,7 +222,7 @@ TEST_F(ValidateDecorations, MixedBuiltInDecorationsBad) { } TEST_F(ValidateDecorations, StructContainsBuiltInStructBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -241,7 +245,7 @@ TEST_F(ValidateDecorations, StructContainsBuiltInStructBad) { } TEST_F(ValidateDecorations, StructContainsNonBuiltInStructGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -254,7 +258,7 @@ TEST_F(ValidateDecorations, StructContainsNonBuiltInStructGood) { } TEST_F(ValidateDecorations, MultipleBuiltInObjectsConsumedByOpEntryPointBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Geometry OpMemoryModel Logical GLSL450 @@ -286,7 +290,7 @@ TEST_F(ValidateDecorations, MultipleBuiltInObjectsConsumedByOpEntryPointBad) { TEST_F(ValidateDecorations, OneBuiltInObjectPerStorageClassConsumedByOpEntryPointGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Geometry OpMemoryModel Logical GLSL450 @@ -313,7 +317,7 @@ TEST_F(ValidateDecorations, } TEST_F(ValidateDecorations, NoBuiltInObjectsConsumedByOpEntryPointGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Geometry OpMemoryModel Logical GLSL450 @@ -338,7 +342,7 @@ TEST_F(ValidateDecorations, NoBuiltInObjectsConsumedByOpEntryPointGood) { } TEST_F(ValidateDecorations, EntryPointFunctionHasLinkageAttributeBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -361,7 +365,7 @@ TEST_F(ValidateDecorations, EntryPointFunctionHasLinkageAttributeBad) { } TEST_F(ValidateDecorations, FunctionDeclarationWithoutImportLinkageBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -379,7 +383,7 @@ TEST_F(ValidateDecorations, FunctionDeclarationWithoutImportLinkageBad) { } TEST_F(ValidateDecorations, FunctionDeclarationWithImportLinkageGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -394,7 +398,7 @@ TEST_F(ValidateDecorations, FunctionDeclarationWithImportLinkageGood) { } TEST_F(ValidateDecorations, FunctionDeclarationWithExportLinkageBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -413,7 +417,7 @@ TEST_F(ValidateDecorations, FunctionDeclarationWithExportLinkageBad) { } TEST_F(ValidateDecorations, FunctionDefinitionWithImportLinkageBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -433,7 +437,7 @@ TEST_F(ValidateDecorations, FunctionDefinitionWithImportLinkageBad) { } TEST_F(ValidateDecorations, FunctionDefinitionWithoutImportLinkageGood) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -548,4 +552,2486 @@ OpFunctionEnd "Component decorations")); } -} // anonymous namespace +// #version 440 +// #extension GL_EXT_nonuniform_qualifier : enable +// layout(binding = 1) uniform sampler2D s2d[]; +// layout(location = 0) in nonuniformEXT int i; +// void main() +// { +// vec4 v = texture(s2d[i], vec2(0.3)); +// } +TEST_F(ValidateDecorations, RuntimeArrayOfDescriptorSetsIsAllowed) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( + OpCapability Shader + OpCapability ShaderNonUniformEXT + OpCapability RuntimeDescriptorArrayEXT + OpCapability SampledImageArrayNonUniformIndexingEXT + OpExtension "SPV_EXT_descriptor_indexing" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %i + OpSource GLSL 440 + OpSourceExtension "GL_EXT_nonuniform_qualifier" + OpName %main "main" + OpName %v "v" + OpName %s2d "s2d" + OpName %i "i" + OpDecorate %s2d DescriptorSet 0 + OpDecorate %s2d Binding 1 + OpDecorate %i Location 0 + OpDecorate %i NonUniformEXT + OpDecorate %18 NonUniformEXT + OpDecorate %21 NonUniformEXT + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown + %11 = OpTypeSampledImage %10 +%_runtimearr_11 = OpTypeRuntimeArray %11 +%_ptr_UniformConstant__runtimearr_11 = OpTypePointer UniformConstant %_runtimearr_11 + %s2d = OpVariable %_ptr_UniformConstant__runtimearr_11 UniformConstant + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %i = OpVariable %_ptr_Input_int Input +%_ptr_UniformConstant_11 = OpTypePointer UniformConstant %11 + %v2float = OpTypeVector %float 2 +%float_0_300000012 = OpConstant %float 0.300000012 + %24 = OpConstantComposite %v2float %float_0_300000012 %float_0_300000012 + %float_0 = OpConstant %float 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %v = OpVariable %_ptr_Function_v4float Function + %18 = OpLoad %int %i + %20 = OpAccessChain %_ptr_UniformConstant_11 %s2d %18 + %21 = OpLoad %11 %20 + %26 = OpImageSampleExplicitLod %v4float %21 %24 Lod %float_0 + OpStore %v %26 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv, env); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +// #version 440 +// #extension GL_EXT_nonuniform_qualifier : enable +// layout(binding = 1) uniform sampler2D s2d[][2]; +// layout(location = 0) in nonuniformEXT int i; +// void main() +// { +// vec4 v = texture(s2d[i][i], vec2(0.3)); +// } +TEST_F(ValidateDecorations, RuntimeArrayOfArraysOfDescriptorSetsIsDisallowed) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( + OpCapability Shader + OpCapability ShaderNonUniformEXT + OpCapability RuntimeDescriptorArrayEXT + OpCapability SampledImageArrayNonUniformIndexingEXT + OpExtension "SPV_EXT_descriptor_indexing" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %i + OpSource GLSL 440 + OpSourceExtension "GL_EXT_nonuniform_qualifier" + OpName %main "main" + OpName %v "v" + OpName %s2d "s2d" + OpName %i "i" + OpDecorate %s2d DescriptorSet 0 + OpDecorate %s2d Binding 1 + OpDecorate %i Location 0 + OpDecorate %i NonUniformEXT + OpDecorate %21 NonUniformEXT + OpDecorate %22 NonUniformEXT + OpDecorate %25 NonUniformEXT + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown + %11 = OpTypeSampledImage %10 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_11_uint_2 = OpTypeArray %11 %uint_2 +%_runtimearr__arr_11_uint_2 = OpTypeRuntimeArray %_arr_11_uint_2 +%_ptr_UniformConstant__runtimearr__arr_11_uint_2 = OpTypePointer UniformConstant %_runtimearr__arr_11_uint_2 + %s2d = OpVariable %_ptr_UniformConstant__runtimearr__arr_11_uint_2 UniformConstant + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %i = OpVariable %_ptr_Input_int Input +%_ptr_UniformConstant_11 = OpTypePointer UniformConstant %11 + %v2float = OpTypeVector %float 2 +%float_0_300000012 = OpConstant %float 0.300000012 + %28 = OpConstantComposite %v2float %float_0_300000012 %float_0_300000012 + %float_0 = OpConstant %float 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %v = OpVariable %_ptr_Function_v4float Function + %21 = OpLoad %int %i + %22 = OpLoad %int %i + %24 = OpAccessChain %_ptr_UniformConstant_11 %s2d %21 %22 + %25 = OpLoad %11 %24 + %30 = OpImageSampleExplicitLod %v4float %25 %28 Lod %float_0 + OpStore %v %30 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv, env); + + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Only a single level of array is allowed for " + "descriptor set variables")); +} + +// #version 440 +// layout (set=1, binding=1) uniform sampler2D variableName[2][2]; +// void main() { +// } +TEST_F(ValidateDecorations, ArrayOfArraysOfDescriptorSetsIsDisallowed) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 440 + OpName %main "main" + OpName %variableName "variableName" + OpDecorate %variableName DescriptorSet 1 + OpDecorate %variableName Binding 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %7 = OpTypeImage %float 2D 0 0 0 1 Unknown + %8 = OpTypeSampledImage %7 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_8_uint_2 = OpTypeArray %8 %uint_2 +%_arr__arr_8_uint_2_uint_2 = OpTypeArray %_arr_8_uint_2 %uint_2 +%_ptr_UniformConstant__arr__arr_8_uint_2_uint_2 = OpTypePointer UniformConstant %_arr__arr_8_uint_2_uint_2 +%variableName = OpVariable %_ptr_UniformConstant__arr__arr_8_uint_2_uint_2 UniformConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv, env); + + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Only a single level of array is allowed for " + "descriptor set variables")); +} + +TEST_F(ValidateDecorations, BlockMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BlockNestedStructMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %v3float %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %v3float %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BlockGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpDecorate %Output GLSLShared + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpDecorate %Output GLSLShared + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BlockNestedStructGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLShared + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLShared + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BlockGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpDecorate %Output GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpDecorate %Output GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BlockNestedStructGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BlockMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %Output = OpTypeStruct %array +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %Output = OpTypeStruct %array +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BlockNestedStructMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %S = OpTypeStruct %array + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %S = OpTypeStruct %array + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BlockMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BlockMissingMatrixStrideArrayBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %matrix %int_3 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingMatrixStrideArrayBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %matrix %int_3 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BlockNestedStructMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %v4float = OpTypeVector %float 4 + %matrix = OpTypeMatrix %v3float 4 + %S = OpTypeStruct %matrix + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %v4float = OpTypeVector %float 4 + %matrix = OpTypeMatrix %v3float 4 + %S = OpTypeStruct %matrix + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BlockStandardUniformBufferLayout) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 32 + OpMemberDecorate %O 3 Offset 64 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 64 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, BlockLayoutPermitsTightVec3ScalarPackingGood) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %v3float %float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, BlockLayoutForbidsTightScalarVec3PackingBad) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 2 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 1 at offset 4 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingWithRelaxedLayoutGood) { + // Same as previous test, but with explicit option to relax block layout. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingBadOffsetWithRelaxedLayoutBad) { + // Same as previous test, but with the vector not aligned to its scalar + // element. Use offset 5 instead of a multiple of 4. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 5 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 5 is not aligned to scalar element size 4")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingWithVulkan1_1Good) { + // Same as previous test, but with Vulkan 1.1. Vulkan 1.1 included + // VK_KHR_relaxed_block_layout in core. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, BufferBlock16bitStandardStorageBufferLayout) { + std::string spirv = R"( + OpCapability Shader + OpCapability StorageUniform16 + OpExtension "SPV_KHR_16bit_storage" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %f32arr ArrayStride 4 + OpDecorate %f16arr ArrayStride 2 + OpMemberDecorate %SSBO32 0 Offset 0 + OpMemberDecorate %SSBO16 0 Offset 0 + OpDecorate %SSBO32 BufferBlock + OpDecorate %SSBO16 BufferBlock + %void = OpTypeVoid + %voidf = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %i32 = OpTypeInt 32 1 + %f32 = OpTypeFloat 32 + %uvec3 = OpTypeVector %u32 3 + %c_i32_32 = OpConstant %i32 32 +%c_i32_128 = OpConstant %i32 128 + %f32arr = OpTypeArray %f32 %c_i32_128 + %f16 = OpTypeFloat 16 + %f16arr = OpTypeArray %f16 %c_i32_128 + %SSBO32 = OpTypeStruct %f32arr + %SSBO16 = OpTypeStruct %f16arr +%_ptr_Uniform_SSBO32 = OpTypePointer Uniform %SSBO32 + %varSSBO32 = OpVariable %_ptr_Uniform_SSBO32 Uniform +%_ptr_Uniform_SSBO16 = OpTypePointer Uniform %SSBO16 + %varSSBO16 = OpVariable %_ptr_Uniform_SSBO16 Uniform + %main = OpFunction %void None %voidf + %label = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, BlockArrayBaseAlignmentGood) { + // For uniform buffer, Array base alignment is 16, and ArrayStride + // must be a multiple of 16. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 16 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_PushConstant_S = OpTypePointer PushConstant %S + %u = OpVariable %_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, BlockArrayBadAlignmentBad) { + // For uniform buffer, Array base alignment is 16. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout rules: " + "member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, BlockArrayBadAlignmentWithRelaxedLayoutStillBad) { + // For uniform buffer, Array base alignment is 16, and ArrayStride + // must be a multiple of 16. This case uses relaxed block layout. Relaxed + // layout only relaxes rules for vector alignment, not array alignment. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout rules: " + "member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, BlockArrayBadAlignmentWithVulkan1_1StillBad) { + // Same as previous test, but with Vulkan 1.1, which includes + // VK_KHR_relaxed_block_layout in core. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout rules: " + "member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, PushConstantArrayBaseAlignmentGood) { + // Tests https://github.com/KhronosGroup/SPIRV-Tools/issues/1664 + // From GLSL vertex shader: + // #version 450 + // layout(push_constant) uniform S { vec2 v; float arr[2]; } u; + // void main() { } + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_PushConstant_S = OpTypePointer PushConstant %S + %u = OpVariable %_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, PushConstantArrayBadAlignmentBad) { + // Like the previous test, but with offset 7 instead of 8. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 7 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_PushConstant_S = OpTypePointer PushConstant %S + %u = OpVariable %_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in PushConstant " + "storage class must follow standard storage buffer layout rules: " + "member 1 at offset 7 is not aligned to 4")); +} + +TEST_F(ValidateDecorations, + PushConstantLayoutPermitsTightVec3ScalarPackingGood) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %v3float %float +%_ptr_PushConstant_S = OpTypePointer PushConstant %S + %B = OpVariable %_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + PushConstantLayoutForbidsTightScalarVec3PackingBad) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer PushConstant %S + %B = OpVariable %_ptr_Uniform_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in PushConstant " + "storage class must follow standard storage buffer layout " + "rules: member 1 at offset 4 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, StorageBufferStorageClassArrayBaseAlignmentGood) { + // Spot check buffer rules when using StorageBuffer storage class with Block + // decoration. + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + OpDecorate %u DescriptorSet 0 + OpDecorate %u Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer StorageBuffer %S + %u = OpVariable %_ptr_Uniform_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, StorageBufferStorageClassArrayBadAlignmentBad) { + // Like the previous test, but with offset 7. + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 7 + OpDecorate %S Block + OpDecorate %u DescriptorSet 0 + OpDecorate %u Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer StorageBuffer %S + %u = OpVariable %_ptr_Uniform_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in StorageBuffer " + "storage class must follow standard storage buffer layout rules: " + "member 1 at offset 7 is not aligned to 4")); +} + +TEST_F(ValidateDecorations, BufferBlockStandardStorageBufferLayout) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 24 + OpMemberDecorate %O 3 Offset 32 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 48 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 144 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 52 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 64 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 96 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, + StorageBufferLayoutPermitsTightVec3ScalarPackingGood) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %v3float %float +%_ptr_StorageBuffer_S = OpTypePointer StorageBuffer %S + %B = OpVariable %_ptr_StorageBuffer_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + StorageBufferLayoutForbidsTightScalarVec3PackingBad) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_StorageBuffer_S = OpTypePointer StorageBuffer %S + %B = OpVariable %_ptr_StorageBuffer_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in StorageBuffer " + "storage class must follow standard storage buffer layout " + "rules: member 1 at offset 4 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockStandardUniformBufferLayoutIncorrectOffset0Bad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 24 + OpMemberDecorate %O 3 Offset 33 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 64 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 6 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 2 at offset 24 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockStandardUniformBufferLayoutIncorrectOffset1Bad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 32 + OpMemberDecorate %O 3 Offset 64 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 71 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 8 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 5 at offset 71 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, BlockUniformBufferLayoutIncorrectArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 49 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 32 + OpMemberDecorate %O 3 Offset 64 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 64 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 6 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 4 is " + "an array with stride 49 not satisfying alignment to 16")); +} + +TEST_F(ValidateDecorations, + BufferBlockStandardStorageBufferLayoutImproperStraddleBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %Output = OpTypeStruct %float %v3float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 3 decorated as BufferBlock for variable in " + "Uniform storage class must follow standard storage buffer " + "layout rules: member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockUniformBufferLayoutOffsetInsideArrayPaddingBad) { + // In this case the 2nd member fits entirely within the padding. + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 20 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %Output = OpTypeStruct %_arr_float_uint_2 %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 4 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 20 overlaps previous member ending at offset 31")); +} + +TEST_F(ValidateDecorations, + BlockUniformBufferLayoutOffsetInsideStructPaddingBad) { + // In this case the 2nd member fits entirely within the padding. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpMemberDecorate %_struct_6 0 Offset 0 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 Offset 4 + OpDecorate %_struct_2 Block + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %_struct_6 = OpTypeStruct %float + %_struct_2 = OpTypeStruct %_struct_6 %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %8 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %4 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 4 overlaps previous member ending at offset 15")); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOutOfOrderGoodUniversal1_0) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 4 + OpMemberDecorate %Outer 1 Offset 0 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Outer = OpTypeStruct %uint %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_UNIVERSAL_1_0)); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOutOfOrderGoodOpenGL4_5) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 4 + OpMemberDecorate %Outer 1 Offset 0 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Outer = OpTypeStruct %uint %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_OPENGL_4_5)); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOutOfOrderGoodVulkan1_1) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 4 + OpMemberDecorate %Outer 1 Offset 0 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Outer = OpTypeStruct %uint %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)) + << getDiagnosticString(); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOverlapBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 0 + OpMemberDecorate %Outer 1 Offset 16 + OpMemberDecorate %Inner 0 Offset 0 + OpMemberDecorate %Inner 1 Offset 16 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Inner = OpTypeStruct %uint %uint + %Outer = OpTypeStruct %Inner %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 16 overlaps previous member ending at offset 31")); +} + +TEST_F(ValidateDecorations, BufferBlockEmptyStruct) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %Output 0 Offset 0 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %S = OpTypeStruct + %Output = OpTypeStruct %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, RowMajorMatrixTightPackingGood) { + // Row major matrix rule: + // A row-major matrix of C columns has a base alignment equal to + // the base alignment of a vector of C matrix components. + // Note: The "matrix component" is the scalar element type. + + // The matrix has 3 columns and 2 rows (C=3, R=2). + // So the base alignment of b is the same as a vector of 3 floats, which is 16 + // bytes. The matrix consists of two of these, and therefore occupies 2 x 16 + // bytes, or 32 bytes. + // + // So the offsets can be: + // a -> 0 + // b -> 16 + // c -> 48 + // d -> 60 ; d fits at bytes 12-15 after offset of c. Tight (vec3;float) + // packing + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpSource GLSL 450 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 RowMajor + OpMemberDecorate %_struct_2 1 Offset 16 + OpMemberDecorate %_struct_2 1 MatrixStride 16 + OpMemberDecorate %_struct_2 2 Offset 48 + OpMemberDecorate %_struct_2 3 Offset 60 + OpDecorate %_struct_2 Block + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%mat3v2float = OpTypeMatrix %v2float 3 + %v3float = OpTypeVector %float 3 + %_struct_2 = OpTypeStruct %v4float %mat3v2float %v3float %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %3 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %5 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, ArrayArrayRowMajorMatrixTightPackingGood) { + // Like the previous case, but we have an array of arrays of matrices. + // The RowMajor decoration goes on the struct member (surprisingly). + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpSource GLSL 450 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 RowMajor + OpMemberDecorate %_struct_2 1 Offset 16 + OpMemberDecorate %_struct_2 1 MatrixStride 16 + OpMemberDecorate %_struct_2 2 Offset 80 + OpMemberDecorate %_struct_2 3 Offset 92 + OpDecorate %arr_mat ArrayStride 32 + OpDecorate %arr_arr_mat ArrayStride 32 + OpDecorate %_struct_2 Block + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%mat3v2float = OpTypeMatrix %v2float 3 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 + %arr_mat = OpTypeArray %mat3v2float %uint_1 +%arr_arr_mat = OpTypeArray %arr_mat %uint_2 + %v3float = OpTypeVector %float 3 + %_struct_2 = OpTypeStruct %v4float %arr_arr_mat %v3float %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %3 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %5 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, ArrayArrayRowMajorMatrixNextMemberOverlapsBad) { + // Like the previous case, but the offset of member 2 overlaps the matrix. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpSource GLSL 450 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 RowMajor + OpMemberDecorate %_struct_2 1 Offset 16 + OpMemberDecorate %_struct_2 1 MatrixStride 16 + OpMemberDecorate %_struct_2 2 Offset 64 + OpMemberDecorate %_struct_2 3 Offset 92 + OpDecorate %arr_mat ArrayStride 32 + OpDecorate %arr_arr_mat ArrayStride 32 + OpDecorate %_struct_2 Block + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%mat3v2float = OpTypeMatrix %v2float 3 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 + %arr_mat = OpTypeArray %mat3v2float %uint_1 +%arr_arr_mat = OpTypeArray %arr_mat %uint_2 + %v3float = OpTypeVector %float 3 + %_struct_2 = OpTypeStruct %v4float %arr_arr_mat %v3float %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %3 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %5 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 2 at " + "offset 64 overlaps previous member ending at offset 79")); +} + +TEST_F(ValidateDecorations, StorageBufferArraySizeCalculationPackGood) { + // Original GLSL + + // #version 450 + // layout (set=0,binding=0) buffer S { + // uvec3 arr[2][2]; // first 3 elements are 16 bytes, last is 12 + // uint i; // Can have offset 60 = 3x16 + 12 + // } B; + // void main() {} + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 60 + OpDecorate %_struct_4 BufferBlock + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, StorageBufferArraySizeCalculationPackBad) { + // Like previous but, the offset of the second member is too small. + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 56 + OpDecorate %_struct_4 BufferBlock + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Structure id 4 decorated as BufferBlock for variable " + "in Uniform storage class must follow standard storage " + "buffer layout rules: member 1 at offset 56 overlaps " + "previous member ending at offset 59")); +} + +TEST_F(ValidateDecorations, UniformBufferArraySizeCalculationPackGood) { + // Like the corresponding buffer block case, but the array padding must + // count for the last element as well, and so the offset of the second + // member must be at least 64. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 64 + OpDecorate %_struct_4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, UniformBufferArraySizeCalculationPackBad) { + // Like previous but, the offset of the second member is too small. + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 60 + OpDecorate %_struct_4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 4 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 60 overlaps previous member ending at offset 63")); +} + +TEST_F(ValidateDecorations, LayoutNotCheckedWhenSkipBlockLayout) { + // Checks that block layout is not verified in skipping block layout mode. + // Even for obviously wrong layout. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 3 ; wrong alignment + OpMemberDecorate %S 1 Offset 3 ; same offset as before! + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetSkipBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, EntryPointVariableWrongStorageClass) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" %var +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_Workgroup Workgroup +%func_ty = OpTypeFunction %void +%1 = OpFunction %void None %func_ty +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpEntryPoint interfaces must be OpVariables with " + "Storage Class of Input(1) or Output(3). Found Storage " + "Class 4 for Entry Point id 1.")); +} +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_derivatives_test.cpp b/3rdparty/spirv-tools/test/val/val_derivatives_test.cpp index 864dad096..93a70e87c 100644 --- a/3rdparty/spirv-tools/test/val/val_derivatives_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_derivatives_test.cpp @@ -16,9 +16,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -38,7 +40,10 @@ OpCapability DerivativeControl ss << capabilities_and_extensions; ss << "OpMemoryModel Logical GLSL450\n"; - ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"" + << " %f32_var_input" + << " %f32vec4_var_input" + << "\n"; ss << R"( %void = OpTypeVoid @@ -145,4 +150,6 @@ TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) { "Derivative instructions require Fragment execution model: DPdx")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_explicit_reserved_test.cpp b/3rdparty/spirv-tools/test/val/val_explicit_reserved_test.cpp new file mode 100644 index 000000000..f01e933fa --- /dev/null +++ b/3rdparty/spirv-tools/test/val/val_explicit_reserved_test.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for illegal instructions + +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; + +using ReservedSamplingInstTest = spvtest::ValidateBase; + +// Generate a shader for use with validation tests for sparse sampling +// instructions. +std::string ShaderAssembly(const std::string& instruction_under_test) { + std::ostringstream os; + os << R"( OpCapability Shader + OpCapability SparseResidency + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + OpSource GLSL 450 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %float_0 = OpConstant %float 0 + %8 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 + %9 = OpTypeImage %float 2D 0 0 0 1 Unknown + %10 = OpTypeSampledImage %9 +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %2 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %v2float = OpTypeVector %float 2 + %13 = OpConstantComposite %v2float %float_0 %float_0 + %int = OpTypeInt 32 1 + %_struct_15 = OpTypeStruct %int %v4float + %1 = OpFunction %void None %4 + %16 = OpLabel + %17 = OpLoad %10 %2 +)" << instruction_under_test + << R"( + OpReturn + OpFunctionEnd +)"; + + return os.str(); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjImplicitLod) { + const std::string input = ShaderAssembly( + "%result = OpImageSparseSampleProjImplicitLod %_struct_15 %17 %13"); + CompileSuccessfully(input); + + EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Invalid Opcode name 'OpImageSparseSampleProjImplicitLod'")); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjExplicitLod) { + const std::string input = ShaderAssembly( + "%result = OpImageSparseSampleProjExplicitLod %_struct_15 %17 %13 Lod " + "%float_0\n"); + CompileSuccessfully(input); + + EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Invalid Opcode name 'OpImageSparseSampleProjExplicitLod'")); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefImplicitLod) { + const std::string input = ShaderAssembly( + "%result = OpImageSparseSampleProjDrefImplicitLod %_struct_15 %17 %13 " + "%float_0\n"); + CompileSuccessfully(input); + + EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Invalid Opcode name 'OpImageSparseSampleProjDrefImplicitLod'")); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefExplicitLod) { + const std::string input = ShaderAssembly( + "%result = OpImageSparseSampleProjDrefExplicitLod %_struct_15 %17 %13 " + "%float_0 Lod " + "%float_0\n"); + CompileSuccessfully(input); + + EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Invalid Opcode name 'OpImageSparseSampleProjDrefExplicitLod'")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_ext_inst_test.cpp b/3rdparty/spirv-tools/test/val/val_ext_inst_test.cpp index ee304ba76..40126fd9d 100644 --- a/3rdparty/spirv-tools/test/val/val_ext_inst_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_ext_inst_test.cpp @@ -18,13 +18,17 @@ #include #include +#include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { +using ::testing::Eq; using ::testing::HasSubstr; using ::testing::Not; @@ -91,7 +95,18 @@ OpCapability Int64 ss << capabilities_and_extensions; ss << "%extinst = OpExtInstImport \"GLSL.std.450\"\n"; ss << "OpMemoryModel Logical GLSL450\n"; - ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"" + << " %f32_output" + << " %f32vec2_output" + << " %u32_output" + << " %u32vec2_output" + << " %u64_output" + << " %f32_input" + << " %f32vec2_input" + << " %u32_input" + << " %u32vec2_input" + << " %u64_input" + << "\n"; ss << R"( %void = OpTypeVoid @@ -148,6 +163,7 @@ OpCapability Int64 %f16_0 = OpConstant %f16 0 %f16_1 = OpConstant %f16 1 +%f16_h = OpConstant %f16 0.5 %u32_0 = OpConstant %u32 0 %u32_1 = OpConstant %u32 1 @@ -216,6 +232,7 @@ OpCapability Int64 %u64_input = OpVariable %u64_ptr_input Input +%struct_f16_u16 = OpTypeStruct %f16 %u16 %struct_f32_f32 = OpTypeStruct %f32 %f32 %struct_f32_f32_f32 = OpTypeStruct %f32 %f32 %f32 %struct_f32_u32 = OpTypeStruct %f32 %u32 @@ -1599,6 +1616,36 @@ TEST_F(ValidateExtInst, GlslStd450FrexpStructXWrongType) { "member of Result Type struct")); } +TEST_F(ValidateExtInst, + GlslStd450FrexpStructResultTypeStructRightInt16Member2) { + const std::string body = R"( +%val1 = OpExtInst %struct_f16_u16 %extinst FrexpStruct %f16_h +)"; + + const std::string extension = R"( +OpExtension "SPV_AMD_gpu_shader_int16" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extension)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, + GlslStd450FrexpStructResultTypeStructWrongInt16Member2) { + const std::string body = R"( +%val1 = OpExtInst %struct_f16_u16 %extinst FrexpStruct %f16_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected Result Type to be a struct with two members, " + "first member a float scalar or vector, second member " + "a 32-bit int scalar or vector with the same number of " + "components as the first member")); +} + TEST_P(ValidateGlslStd450Pack, Success) { const std::string ext_inst_name = GetParam(); const uint32_t num_components = GetPackedNumComponents(ext_inst_name); @@ -2326,21 +2373,30 @@ TEST_F(ValidateExtInst, GlslStd450RefractIntEta) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("GLSL.std.450 Refract: " - "expected operand Eta to be a 16 or 32-bit " - "float scalar")); + "expected operand Eta to be a float scalar")); } TEST_F(ValidateExtInst, GlslStd450RefractFloat64Eta) { + // SPIR-V issue 337: Eta can be 64-bit float scalar. const std::string body = R"( %val1 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractVectorEta) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %f32vec2_01 )"; CompileSuccessfully(GenerateShaderCode(body)); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("GLSL.std.450 Refract: " - "expected operand Eta to be a 16 or 32-bit " - "float scalar")); + "expected operand Eta to be a float scalar")); } TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidSuccess) { @@ -5758,4 +5814,6 @@ INSTANTIATE_TEST_CASE_P(AllUpsampleLike, ValidateOpenCLStdUpsampleLike, "s_upsample", }), ); -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_extensions_test.cpp b/3rdparty/spirv-tools/test/val/val_extensions_test.cpp index d5ecacda6..b185c3ca7 100644 --- a/3rdparty/spirv-tools/test/val/val_extensions_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_extensions_test.cpp @@ -15,32 +15,31 @@ // Tests for OpExtension validator rules. #include +#include -#include "enum_string_mapping.h" -#include "extensions.h" #include "gmock/gmock.h" -#include "spirv_target_env.h" -#include "test_fixture.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/spirv_target_env.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { -using ::libspirv::Extension; - using ::testing::HasSubstr; using ::testing::Not; using ::testing::Values; using ::testing::ValuesIn; -using std::string; - -using ValidateKnownExtensions = spvtest::ValidateBase; -using ValidateUnknownExtensions = spvtest::ValidateBase; +using ValidateKnownExtensions = spvtest::ValidateBase; +using ValidateUnknownExtensions = spvtest::ValidateBase; using ValidateExtensionCapabilities = spvtest::ValidateBase; // Returns expected error string if |extension| is not recognized. -string GetErrorString(const std::string& extension) { +std::string GetErrorString(const std::string& extension) { return "Found unrecognized extension " + extension; } @@ -71,7 +70,7 @@ INSTANTIATE_TEST_CASE_P(FailSilently, ValidateUnknownExtensions, TEST_P(ValidateKnownExtensions, ExpectSuccess) { const std::string extension = GetParam(); - const string str = + const std::string str = "OpCapability Shader\nOpCapability Linkage\nOpExtension \"" + extension + "\"\nOpMemoryModel Logical GLSL450"; CompileSuccessfully(str.c_str()); @@ -81,7 +80,7 @@ TEST_P(ValidateKnownExtensions, ExpectSuccess) { TEST_P(ValidateUnknownExtensions, FailSilently) { const std::string extension = GetParam(); - const string str = + const std::string str = "OpCapability Shader\nOpCapability Linkage\nOpExtension \"" + extension + "\"\nOpMemoryModel Logical GLSL450"; CompileSuccessfully(str.c_str()); @@ -90,7 +89,7 @@ TEST_P(ValidateUnknownExtensions, FailSilently) { } TEST_F(ValidateExtensionCapabilities, DeclCapabilitySuccess) { - const string str = + const std::string str = "OpCapability Shader\nOpCapability Linkage\nOpCapability DeviceGroup\n" "OpExtension \"SPV_KHR_device_group\"" "\nOpMemoryModel Logical GLSL450"; @@ -99,7 +98,7 @@ TEST_F(ValidateExtensionCapabilities, DeclCapabilitySuccess) { } TEST_F(ValidateExtensionCapabilities, DeclCapabilityFailure) { - const string str = + const std::string str = "OpCapability Shader\nOpCapability Linkage\nOpCapability DeviceGroup\n" "\nOpMemoryModel Logical GLSL450"; CompileSuccessfully(str.c_str()); @@ -110,16 +109,16 @@ TEST_F(ValidateExtensionCapabilities, DeclCapabilityFailure) { EXPECT_THAT(getDiagnosticString(), HasSubstr("SPV_KHR_device_group")); } -using ValidateAMDShaderBallotCapabilities = spvtest::ValidateBase; +using ValidateAMDShaderBallotCapabilities = spvtest::ValidateBase; // Returns a vector of strings for the prefix of a SPIR-V assembly shader // that can use the group instructions introduced by SPV_AMD_shader_ballot. -std::vector ShaderPartsForAMDShaderBallot() { - return std::vector{R"( +std::vector ShaderPartsForAMDShaderBallot() { + return std::vector{R"( OpCapability Shader OpCapability Linkage )", - R"( + R"( OpMemoryModel Logical GLSL450 %float = OpTypeFloat 32 %uint = OpTypeInt 32 0 @@ -139,8 +138,8 @@ std::vector ShaderPartsForAMDShaderBallot() { // Returns a list of SPIR-V assembly strings, where each uses only types // and IDs that can fit with a shader made from parts from the result // of ShaderPartsForAMDShaderBallot. -std::vector AMDShaderBallotGroupInstructions() { - return std::vector{ +std::vector AMDShaderBallotGroupInstructions() { + return std::vector{ "%iadd_reduce = OpGroupIAddNonUniformAMD %uint %scope Reduce %uint_const", "%iadd_iscan = OpGroupIAddNonUniformAMD %uint %scope InclusiveScan " "%uint_const", @@ -197,8 +196,9 @@ TEST_P(ValidateAMDShaderBallotCapabilities, ExpectSuccess) { // Succeed because the module specifies the SPV_AMD_shader_ballot extension. auto parts = ShaderPartsForAMDShaderBallot(); - const string assembly = parts[0] + "OpExtension \"SPV_AMD_shader_ballot\"\n" + - parts[1] + GetParam() + "\nOpReturn OpFunctionEnd"; + const std::string assembly = + parts[0] + "OpExtension \"SPV_AMD_shader_ballot\"\n" + parts[1] + + GetParam() + "\nOpReturn OpFunctionEnd"; CompileSuccessfully(assembly.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); @@ -212,7 +212,7 @@ TEST_P(ValidateAMDShaderBallotCapabilities, ExpectFailure) { // extension. auto parts = ShaderPartsForAMDShaderBallot(); - const string assembly = + const std::string assembly = parts[0] + parts[1] + GetParam() + "\nOpReturn OpFunctionEnd"; CompileSuccessfully(assembly.c_str()); @@ -222,9 +222,10 @@ TEST_P(ValidateAMDShaderBallotCapabilities, ExpectFailure) { // Find just the opcode name, skipping over the "Op" part. auto prefix_with_opcode = GetParam().substr(GetParam().find("Group")); auto opcode = prefix_with_opcode.substr(0, prefix_with_opcode.find(' ')); - EXPECT_THAT(getDiagnosticString(), - HasSubstr(string("Opcode " + opcode + - " requires one of these capabilities: Groups"))); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr(std::string("Opcode " + opcode + + " requires one of these capabilities: Groups"))); } INSTANTIATE_TEST_CASE_P(ExpectFailure, ValidateAMDShaderBallotCapabilities, @@ -244,10 +245,10 @@ using ValidateExtIntoCore = spvtest::ValidateBase; // functionalities that introduced in extensions but became core SPIR-V later. TEST_P(ValidateExtIntoCore, DoNotAskForExtensionInLaterVersion) { - const string code = string(R"( + const std::string code = std::string(R"( OpCapability Shader OpCapability )") + - GetParam().cap + R"( + GetParam().cap + R"( OpMemoryModel Logical GLSL450 OpEntryPoint Vertex %main "main" %builtin OpDecorate %builtin BuiltIn )" + GetParam().builtin + R"( @@ -267,14 +268,14 @@ TEST_P(ValidateExtIntoCore, DoNotAskForExtensionInLaterVersion) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(GetParam().env)); } else { ASSERT_NE(SPV_SUCCESS, ValidateInstructions(GetParam().env)); - const string message = getDiagnosticString(); + const std::string message = getDiagnosticString(); if (spvIsVulkanEnv(GetParam().env)) { - EXPECT_THAT(message, HasSubstr(string(GetParam().cap) + + EXPECT_THAT(message, HasSubstr(std::string(GetParam().cap) + " is not allowed by Vulkan")); - EXPECT_THAT(message, HasSubstr(string("or requires extension"))); + EXPECT_THAT(message, HasSubstr(std::string("or requires extension"))); } else { EXPECT_THAT(message, - HasSubstr(string("requires one of these extensions: ") + + HasSubstr(std::string("requires one of these extensions: ") + GetParam().ext)); } } @@ -316,4 +317,6 @@ INSTANTIATE_TEST_CASE_P( })); // clang-format on -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_fixtures.h b/3rdparty/spirv-tools/test/val/val_fixtures.h index 3f53e9676..73a0cc624 100644 --- a/3rdparty/spirv-tools/test/val/val_fixtures.h +++ b/3rdparty/spirv-tools/test/val/val_fixtures.h @@ -14,12 +14,15 @@ // Common validation fixtures for unit tests -#ifndef LIBSPIRV_TEST_VALIDATE_FIXTURES_H_ -#define LIBSPIRV_TEST_VALIDATE_FIXTURES_H_ +#ifndef TEST_VAL_VAL_FIXTURES_H_ +#define TEST_VAL_VAL_FIXTURES_H_ + +#include +#include #include "source/val/validation_state.h" -#include "test_fixture.h" -#include "unit_spirv.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" namespace spvtest { @@ -34,6 +37,8 @@ class ValidateBase : public ::testing::Test, // Returns the a spv_const_binary struct spv_const_binary get_const_binary(); + // Checks that 'code' is valid SPIR-V text representation and stores the + // binary version for further method calls. void CompileSuccessfully(std::string code, spv_target_env env = SPV_ENV_UNIVERSAL_1_0); @@ -43,8 +48,7 @@ class ValidateBase : public ::testing::Test, // This function overwrites the word at the given index with a new word. void OverwriteAssembledBinary(uint32_t index, uint32_t word); - // Performs validation on the SPIR-V code and compares the result of the - // spvValidate function + // Performs validation on the SPIR-V code. spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0); // Performs validation. Returns the status and stores validation state into @@ -59,7 +63,7 @@ class ValidateBase : public ::testing::Test, spv_binary binary_; spv_diagnostic diagnostic_; spv_validator_options options_; - std::unique_ptr vstate_; + std::unique_ptr vstate_; }; template @@ -113,7 +117,7 @@ spv_result_t ValidateBase::ValidateInstructions(spv_target_env env) { template spv_result_t ValidateBase::ValidateAndRetrieveValidationState( spv_target_env env) { - return spvtools::ValidateBinaryAndKeepValidationState( + return spvtools::val::ValidateBinaryAndKeepValidationState( ScopedContext(env).context, options_, get_const_binary()->code, get_const_binary()->wordCount, &diagnostic_, &vstate_); } @@ -135,4 +139,5 @@ spv_position_t ValidateBase::getErrorPosition() { } } // namespace spvtest -#endif + +#endif // TEST_VAL_VAL_FIXTURES_H_ diff --git a/3rdparty/spirv-tools/test/val/val_id_test.cpp b/3rdparty/spirv-tools/test/val/val_id_test.cpp index 195a65cf2..6d907b188 100644 --- a/3rdparty/spirv-tools/test/val/val_id_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_id_test.cpp @@ -14,11 +14,12 @@ #include #include +#include #include "gmock/gmock.h" -#include "test_fixture.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" // NOTE: The tests in this file are ONLY testing ID usage, there for the input // SPIR-V does not follow the logical layout rules from the spec in all cases in @@ -26,25 +27,38 @@ // in stages, ID validation is only one of these stages. All validation stages // are stand alone. +namespace spvtools { +namespace val { namespace { using spvtest::ScopedContext; -using std::ostringstream; -using std::string; -using std::vector; using ::testing::HasSubstr; using ::testing::ValuesIn; using ValidateIdWithMessage = spvtest::ValidateBase; -string kGLSL450MemoryModel = R"( +std::string kOpCapabilitySetup = R"( OpCapability Shader OpCapability Linkage OpCapability Addresses - OpCapability Pipes + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float64 OpCapability LiteralSampler + OpCapability Pipes OpCapability DeviceEnqueue OpCapability Vector16 +)"; + +std::string kGLSL450MemoryModel = kOpCapabilitySetup + R"( + OpMemoryModel Logical GLSL450 +)"; + +std::string kNoKernelGLSL450MemoryModel = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Addresses OpCapability Int8 OpCapability Int16 OpCapability Int64 @@ -52,7 +66,7 @@ string kGLSL450MemoryModel = R"( OpMemoryModel Logical GLSL450 )"; -string kOpenCLMemoryModel32 = R"( +std::string kOpenCLMemoryModel32 = R"( OpCapability Addresses OpCapability Linkage OpCapability Kernel @@ -60,7 +74,7 @@ string kOpenCLMemoryModel32 = R"( OpMemoryModel Physical32 OpenCL )"; -string kOpenCLMemoryModel64 = R"( +std::string kOpenCLMemoryModel64 = R"( OpCapability Addresses OpCapability Linkage OpCapability Kernel @@ -69,7 +83,7 @@ string kOpenCLMemoryModel64 = R"( OpMemoryModel Physical64 OpenCL )"; -string sampledImageSetup = R"( +std::string sampledImageSetup = R"( %void = OpTypeVoid %typeFuncVoid = OpTypeFunction %void %float = OpTypeFloat 32 @@ -94,7 +108,7 @@ string sampledImageSetup = R"( %sampler_inst = OpLoad %sampler_type %s )"; -string BranchConditionalSetup = R"( +std::string BranchConditionalSetup = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -111,6 +125,7 @@ string BranchConditionalSetup = R"( %v4float = OpTypeVector %float 4 ; constants + %true = OpConstantTrue %bool %i0 = OpConstant %int 0 %i1 = OpConstant %int 1 %f0 = OpConstant %float 0 @@ -124,7 +139,7 @@ string BranchConditionalSetup = R"( %lmain = OpLabel )"; -string BranchConditionalTail = R"( +std::string BranchConditionalTail = R"( %target_t = OpLabel OpNop OpBranch %end @@ -141,7 +156,7 @@ string BranchConditionalTail = R"( // TODO: OpUndef TEST_F(ValidateIdWithMessage, OpName) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpName %2 "name" %1 = OpTypeInt 32 0 %2 = OpTypePointer UniformConstant %1 @@ -151,7 +166,7 @@ TEST_F(ValidateIdWithMessage, OpName) { } TEST_F(ValidateIdWithMessage, OpMemberNameGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberName %2 0 "foo" %1 = OpTypeInt 32 0 %2 = OpTypeStruct %1)"; @@ -159,28 +174,30 @@ TEST_F(ValidateIdWithMessage, OpMemberNameGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpMemberNameTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberName %1 0 "foo" %1 = OpTypeInt 32 0)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpMemberName Type '1' is not a struct type.")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpMemberName Type '1[foo]' is not a struct type.")); } TEST_F(ValidateIdWithMessage, OpMemberNameMemberBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberName %1 1 "foo" %2 = OpTypeInt 32 0 %1 = OpTypeStruct %2)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpMemberName Member '1' index is larger than " - "Type '1's member count.")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpMemberName Member '1[foo]' index is larger than " + "Type '1[foo]'s member count.")); } TEST_F(ValidateIdWithMessage, OpLineGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpString "/path/to/source.file" OpLine %1 0 0 %2 = OpTypeInt 32 0 @@ -191,7 +208,7 @@ TEST_F(ValidateIdWithMessage, OpLineGood) { } TEST_F(ValidateIdWithMessage, OpLineFileBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 OpLine %1 0 0 )"; @@ -202,7 +219,7 @@ TEST_F(ValidateIdWithMessage, OpLineFileBad) { } TEST_F(ValidateIdWithMessage, OpDecorateGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %2 GLSLShared %1 = OpTypeInt 64 0 %2 = OpTypeStruct %1 %1)"; @@ -210,7 +227,7 @@ TEST_F(ValidateIdWithMessage, OpDecorateGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpDecorateBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %1 GLSLShared)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); @@ -219,7 +236,7 @@ OpDecorate %1 GLSLShared)"; } TEST_F(ValidateIdWithMessage, OpMemberDecorateGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %2 0 Uniform %1 = OpTypeInt 32 0 %2 = OpTypeStruct %1 %1)"; @@ -227,7 +244,7 @@ TEST_F(ValidateIdWithMessage, OpMemberDecorateGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpMemberDecorateBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Uniform %1 = OpTypeInt 32 0)"; CompileSuccessfully(spirv.c_str()); @@ -238,7 +255,7 @@ TEST_F(ValidateIdWithMessage, OpMemberDecorateBad) { "OpMemberDecorate Structure type '1' is not a struct type.")); } TEST_F(ValidateIdWithMessage, OpMemberDecorateMemberBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 3 Uniform %int = OpTypeInt 32 0 %1 = OpTypeStruct %int %int)"; @@ -251,7 +268,7 @@ TEST_F(ValidateIdWithMessage, OpMemberDecorateMemberBad) { } TEST_F(ValidateIdWithMessage, OpGroupDecorateGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpDecorationGroup OpDecorate %1 Uniform OpDecorate %1 GLSLShared @@ -263,7 +280,7 @@ TEST_F(ValidateIdWithMessage, OpGroupDecorateGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpDecorationGroupBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpDecorationGroup OpDecorate %1 Uniform OpDecorate %1 GLSLShared @@ -277,7 +294,7 @@ TEST_F(ValidateIdWithMessage, OpDecorationGroupBad) { "OpDecorate, and OpGroupMemberDecorate")); } TEST_F(ValidateIdWithMessage, OpGroupDecorateDecorationGroupBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage %1 = OpExtInstImport "GLSL.std.450" @@ -292,7 +309,7 @@ TEST_F(ValidateIdWithMessage, OpGroupDecorateDecorationGroupBad) { "decoration group.")); } TEST_F(ValidateIdWithMessage, OpGroupDecorateTargetBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpDecorationGroup OpDecorate %1 Uniform OpDecorate %1 GLSLShared @@ -304,7 +321,7 @@ TEST_F(ValidateIdWithMessage, OpGroupDecorateTargetBad) { HasSubstr("forward referenced IDs have not been defined")); } TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateDecorationGroupBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Linkage %1 = OpExtInstImport "GLSL.std.450" @@ -318,7 +335,7 @@ TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateDecorationGroupBad) { "not a decoration group.")); } TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateIdNotStructBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpDecorationGroup OpGroupMemberDecorate %1 %2 0 %2 = OpTypeInt 32 0)"; @@ -329,7 +346,7 @@ TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateIdNotStructBad) { "a struct type.")); } TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateIndexOutOfBoundBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %1 Offset 0 %1 = OpDecorationGroup OpGroupMemberDecorate %1 %struct 3 @@ -347,7 +364,7 @@ TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateIndexOutOfBoundBad) { // TODO: OpExtInst TEST_F(ValidateIdWithMessage, OpEntryPointGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpEntryPoint GLCompute %3 "" %1 = OpTypeVoid %2 = OpTypeFunction %1 @@ -360,7 +377,7 @@ TEST_F(ValidateIdWithMessage, OpEntryPointGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpEntryPointFunctionBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpEntryPoint GLCompute %1 "" %1 = OpTypeVoid)"; CompileSuccessfully(spirv.c_str()); @@ -370,7 +387,7 @@ TEST_F(ValidateIdWithMessage, OpEntryPointFunctionBad) { HasSubstr("OpEntryPoint Entry Point '1' is not a function.")); } TEST_F(ValidateIdWithMessage, OpEntryPointParameterCountBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpEntryPoint GLCompute %3 "" %1 = OpTypeVoid %2 = OpTypeFunction %1 %1 @@ -385,7 +402,7 @@ TEST_F(ValidateIdWithMessage, OpEntryPointParameterCountBad) { "count is not zero")); } TEST_F(ValidateIdWithMessage, OpEntryPointReturnTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpEntryPoint GLCompute %3 "" %1 = OpTypeInt 32 0 %ret = OpConstant %1 0 @@ -402,7 +419,7 @@ TEST_F(ValidateIdWithMessage, OpEntryPointReturnTypeBad) { } TEST_F(ValidateIdWithMessage, OpEntryPointInterfaceIsNotVariableTypeBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Geometry OpMemoryModel Logical GLSL450 @@ -426,7 +443,7 @@ TEST_F(ValidateIdWithMessage, OpEntryPointInterfaceIsNotVariableTypeBad) { } TEST_F(ValidateIdWithMessage, OpEntryPointInterfaceStorageClassBad) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader OpCapability Geometry OpMemoryModel Logical GLSL450 @@ -452,7 +469,7 @@ TEST_F(ValidateIdWithMessage, OpEntryPointInterfaceStorageClassBad) { } TEST_F(ValidateIdWithMessage, OpExecutionModeGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpEntryPoint GLCompute %3 "" OpExecutionMode %3 LocalSize 1 1 1 %1 = OpTypeVoid @@ -466,7 +483,7 @@ TEST_F(ValidateIdWithMessage, OpExecutionModeGood) { } TEST_F(ValidateIdWithMessage, OpExecutionModeEntryPointMissing) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpExecutionMode %3 LocalSize 1 1 1 %1 = OpTypeVoid %2 = OpTypeFunction %1 @@ -482,7 +499,7 @@ TEST_F(ValidateIdWithMessage, OpExecutionModeEntryPointMissing) { } TEST_F(ValidateIdWithMessage, OpExecutionModeEntryPointBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpEntryPoint GLCompute %3 "" %a OpExecutionMode %a LocalSize 1 1 1 %void = OpTypeVoid @@ -501,7 +518,7 @@ TEST_F(ValidateIdWithMessage, OpExecutionModeEntryPointBad) { } TEST_F(ValidateIdWithMessage, OpTypeVectorFloat) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4)"; CompileSuccessfully(spirv.c_str()); @@ -509,7 +526,7 @@ TEST_F(ValidateIdWithMessage, OpTypeVectorFloat) { } TEST_F(ValidateIdWithMessage, OpTypeVectorInt) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeVector %1 4)"; CompileSuccessfully(spirv.c_str()); @@ -517,7 +534,7 @@ TEST_F(ValidateIdWithMessage, OpTypeVectorInt) { } TEST_F(ValidateIdWithMessage, OpTypeVectorUInt) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 64 0 %2 = OpTypeVector %1 4)"; CompileSuccessfully(spirv.c_str()); @@ -525,7 +542,7 @@ TEST_F(ValidateIdWithMessage, OpTypeVectorUInt) { } TEST_F(ValidateIdWithMessage, OpTypeVectorBool) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeBool %2 = OpTypeVector %1 4)"; CompileSuccessfully(spirv.c_str()); @@ -533,7 +550,7 @@ TEST_F(ValidateIdWithMessage, OpTypeVectorBool) { } TEST_F(ValidateIdWithMessage, OpTypeVectorComponentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypePointer UniformConstant %1 %3 = OpTypeVector %2 4)"; @@ -545,7 +562,7 @@ TEST_F(ValidateIdWithMessage, OpTypeVectorComponentTypeBad) { } TEST_F(ValidateIdWithMessage, OpTypeMatrixGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 2 %3 = OpTypeMatrix %2 3)"; @@ -553,7 +570,7 @@ TEST_F(ValidateIdWithMessage, OpTypeMatrixGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpTypeMatrixColumnTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeMatrix %1 3)"; CompileSuccessfully(spirv.c_str()); @@ -564,14 +581,14 @@ TEST_F(ValidateIdWithMessage, OpTypeMatrixColumnTypeBad) { TEST_F(ValidateIdWithMessage, OpTypeSamplerGood) { // In Rev31, OpTypeSampler takes no arguments. - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %s = OpTypeSampler)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpTypeArrayGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 1 %3 = OpTypeArray %1 %2)"; @@ -580,7 +597,7 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayGood) { } TEST_F(ValidateIdWithMessage, OpTypeArrayElementTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 1 %3 = OpTypeArray %2 %2)"; @@ -594,8 +611,9 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayElementTypeBad) { enum Signed { kSigned, kUnsigned }; // Creates an assembly snippet declaring OpTypeArray with the given length. -string MakeArrayLength(const string& len, Signed isSigned, int width) { - ostringstream ss; +std::string MakeArrayLength(const std::string& len, Signed isSigned, + int width) { + std::ostringstream ss; ss << R"( OpCapability Shader OpCapability Linkage @@ -651,7 +669,7 @@ TEST_P(OpTypeArrayLengthTest, LengthPositive) { Val(CompileSuccessfully(MakeArrayLength("55", kSigned, width)))); EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully(MakeArrayLength("55", kUnsigned, width)))); - const string fpad(width / 4 - 1, 'F'); + const std::string fpad(width / 4 - 1, 'F'); EXPECT_EQ( SPV_SUCCESS, Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width)))); @@ -685,7 +703,7 @@ TEST_P(OpTypeArrayLengthTest, LengthNegative) { SPV_ERROR_INVALID_ID, Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)), "OpTypeArray Length '2' default value must be at least 1.")); - const string neg_max = "0x8" + string(width / 4 - 1, '0'); + const std::string neg_max = "0x8" + std::string(width / 4 - 1, '0'); EXPECT_EQ( SPV_ERROR_INVALID_ID, Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)), @@ -698,10 +716,10 @@ TEST_P(OpTypeArrayLengthTest, LengthNegative) { // here since the purpose of these tests is to check the validity of // OpTypeArray, not OpTypeInt. INSTANTIATE_TEST_CASE_P(Widths, OpTypeArrayLengthTest, - ValuesIn(vector{16, 32, 64})); + ValuesIn(std::vector{16, 32, 64})); TEST_F(ValidateIdWithMessage, OpTypeArrayLengthNull) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %i32 = OpTypeInt 32 0 %len = OpConstantNull %i32 %ary = OpTypeArray %i32 %len)"; @@ -714,7 +732,7 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayLengthNull) { } TEST_F(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %i32 = OpTypeInt 32 0 %len = OpSpecConstant %i32 2 %ary = OpTypeArray %i32 %len)"; @@ -723,7 +741,7 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) { } TEST_F(ValidateIdWithMessage, OpTypeArrayLengthSpecConstOp) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %i32 = OpTypeInt 32 0 %c1 = OpConstant %i32 1 %c2 = OpConstant %i32 2 @@ -734,14 +752,14 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayLengthSpecConstOp) { } TEST_F(ValidateIdWithMessage, OpTypeRuntimeArrayGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeRuntimeArray %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpTypeRuntimeArrayBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 0 %3 = OpTypeRuntimeArray %2)"; @@ -755,7 +773,7 @@ TEST_F(ValidateIdWithMessage, OpTypeRuntimeArrayBad) { // Unifrom Storage Class TEST_F(ValidateIdWithMessage, OpTypeStructGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeFloat 64 %3 = OpTypePointer Input %1 @@ -764,7 +782,7 @@ TEST_F(ValidateIdWithMessage, OpTypeStructGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpTypeStructMemberTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeFloat 64 %3 = OpConstant %2 0.0 @@ -776,14 +794,14 @@ TEST_F(ValidateIdWithMessage, OpTypeStructMemberTypeBad) { } TEST_F(ValidateIdWithMessage, OpTypePointerGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypePointer Input %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpTypePointerBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 0 %3 = OpTypePointer Input %2)"; @@ -794,14 +812,14 @@ TEST_F(ValidateIdWithMessage, OpTypePointerBad) { } TEST_F(ValidateIdWithMessage, OpTypeFunctionGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeFunction %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpTypeFunctionReturnTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 0 %3 = OpTypeFunction %2)"; @@ -811,7 +829,7 @@ TEST_F(ValidateIdWithMessage, OpTypeFunctionReturnTypeBad) { HasSubstr("OpTypeFunction Return Type '2' is not a type.")); } TEST_F(ValidateIdWithMessage, OpTypeFunctionParameterBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpConstant %2 0 @@ -824,7 +842,7 @@ TEST_F(ValidateIdWithMessage, OpTypeFunctionParameterBad) { } TEST_F(ValidateIdWithMessage, OpTypePipeGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 16 %3 = OpTypePipe ReadOnly)"; @@ -833,14 +851,14 @@ TEST_F(ValidateIdWithMessage, OpTypePipeGood) { } TEST_F(ValidateIdWithMessage, OpConstantTrueGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeBool %2 = OpConstantTrue %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantTrueBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpConstantTrue %1)"; CompileSuccessfully(spirv.c_str()); @@ -851,14 +869,14 @@ TEST_F(ValidateIdWithMessage, OpConstantTrueBad) { } TEST_F(ValidateIdWithMessage, OpConstantFalseGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeBool %2 = OpConstantTrue %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantFalseBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpConstantFalse %1)"; CompileSuccessfully(spirv.c_str()); @@ -869,14 +887,14 @@ TEST_F(ValidateIdWithMessage, OpConstantFalseBad) { } TEST_F(ValidateIdWithMessage, OpConstantGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpConstant !1 !0)"; // The expected failure code is implementation dependent (currently @@ -887,7 +905,7 @@ TEST_F(ValidateIdWithMessage, OpConstantBad) { } TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpConstant %1 3.14 @@ -896,7 +914,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorWithUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpConstant %1 3.14 @@ -906,7 +924,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorWithUndefGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpConstant %1 3.14 @@ -919,7 +937,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorResultTypeBad) { "OpConstantComposite Result Type '1' is not a composite type.")); } TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorConstituentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %4 = OpTypeInt 32 0 @@ -935,7 +953,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorConstituentTypeBad) { } TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorConstituentUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %4 = OpTypeInt 32 0 @@ -950,7 +968,7 @@ TEST_F(ValidateIdWithMessage, "Result Type '2's vector element type.")); } TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeMatrix %2 4 @@ -965,7 +983,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeMatrix %2 4 @@ -980,7 +998,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixUndefGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixConstituentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %11 = OpTypeVector %1 3 @@ -1001,7 +1019,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixConstituentTypeBad) { } TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixConstituentUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %11 = OpTypeVector %1 3 @@ -1020,25 +1038,8 @@ TEST_F(ValidateIdWithMessage, "component count does not match Result Type '4's " "vector component count.")); } -TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixColumnTypeBad) { - string spirv = kGLSL450MemoryModel + R"( - %1 = OpTypeInt 32 0 - %2 = OpTypeFloat 32 - %3 = OpTypeVector %1 2 - %4 = OpTypeVector %3 2 - %5 = OpTypeMatrix %2 2 - %6 = OpConstant %1 42 - %7 = OpConstant %2 3.14 - %8 = OpConstantComposite %3 %6 %6 - %9 = OpConstantComposite %4 %7 %7 -%10 = OpConstantComposite %5 %8 %9)"; - CompileSuccessfully(spirv.c_str()); - EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Columns in a matrix must be of type vector.")); -} TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1047,7 +1048,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayWithUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %9 = OpUndef %1 @@ -1057,7 +1058,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayWithUndefGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstConstituentBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1069,7 +1070,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstConstituentBad) { "constant or undef.")); } TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstituentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1083,7 +1084,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstituentTypeBad) { "not match Result Type '3's array element type.")); } TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstituentUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1097,7 +1098,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstituentUndefTypeBad) { "not match Result Type '3's array element type.")); } TEST_F(ValidateIdWithMessage, OpConstantCompositeStructGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1108,7 +1109,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeStructGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeStructUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1119,7 +1120,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeStructUndefGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantCompositeStructMemberTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1134,7 +1135,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeStructMemberTypeBad) { } TEST_F(ValidateIdWithMessage, OpConstantCompositeStructMemberUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1149,7 +1150,7 @@ TEST_F(ValidateIdWithMessage, OpConstantCompositeStructMemberUndefTypeBad) { } TEST_F(ValidateIdWithMessage, OpConstantSamplerGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %float = OpTypeFloat 32 %samplerType = OpTypeSampler %3 = OpConstantSampler %samplerType ClampToEdge 0 Nearest)"; @@ -1157,7 +1158,7 @@ TEST_F(ValidateIdWithMessage, OpConstantSamplerGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpConstantSamplerResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpConstantSampler %1 Clamp 0 Nearest)"; CompileSuccessfully(spirv.c_str()); @@ -1169,7 +1170,7 @@ TEST_F(ValidateIdWithMessage, OpConstantSamplerResultTypeBad) { } TEST_F(ValidateIdWithMessage, OpConstantNullGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeBool %2 = OpConstantNull %1 %3 = OpTypeInt 32 0 @@ -1205,7 +1206,7 @@ TEST_F(ValidateIdWithMessage, OpConstantNullGood) { } TEST_F(ValidateIdWithMessage, OpConstantNullBasicBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpConstantNull %1)"; CompileSuccessfully(spirv.c_str()); @@ -1217,7 +1218,7 @@ TEST_F(ValidateIdWithMessage, OpConstantNullBasicBad) { } TEST_F(ValidateIdWithMessage, OpConstantNullArrayBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %2 = OpTypeInt 32 0 %3 = OpTypeSampler %4 = OpConstant %2 4 @@ -1232,7 +1233,7 @@ TEST_F(ValidateIdWithMessage, OpConstantNullArrayBad) { } TEST_F(ValidateIdWithMessage, OpConstantNullStructBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %2 = OpTypeSampler %3 = OpTypeStruct %2 %2 %4 = OpConstantNull %3)"; @@ -1245,7 +1246,7 @@ TEST_F(ValidateIdWithMessage, OpConstantNullStructBad) { } TEST_F(ValidateIdWithMessage, OpConstantNullRuntimeArrayBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %bool = OpTypeBool %array = OpTypeRuntimeArray %bool %null = OpConstantNull %array)"; @@ -1258,14 +1259,14 @@ TEST_F(ValidateIdWithMessage, OpConstantNullRuntimeArrayBad) { } TEST_F(ValidateIdWithMessage, OpSpecConstantTrueGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeBool %2 = OpSpecConstantTrue %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpSpecConstantTrueBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpSpecConstantTrue %1)"; CompileSuccessfully(spirv.c_str()); @@ -1275,14 +1276,14 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantTrueBad) { } TEST_F(ValidateIdWithMessage, OpSpecConstantFalseGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeBool %2 = OpSpecConstantFalse %1)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpSpecConstantFalseBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpSpecConstantFalse %1)"; CompileSuccessfully(spirv.c_str()); @@ -1292,14 +1293,14 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantFalseBad) { } TEST_F(ValidateIdWithMessage, OpSpecConstantGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpSpecConstant %1 42)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpSpecConstantBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpSpecConstant !1 !4)"; // The expected failure code is implementation dependent (currently @@ -1313,7 +1314,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantBad) { // Valid: SpecConstantComposite specializes to a vector. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpSpecConstant %1 3.14 @@ -1325,7 +1326,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorGood) { // Valid: Vector of floats and Undefs. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorWithUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpSpecConstant %1 3.14 @@ -1338,7 +1339,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorWithUndefGood) { // Invalid: result type is float. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpSpecConstant %1 3.14 @@ -1350,7 +1351,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorResultTypeBad) { // Invalid: Vector contains a mix of Int and Float. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorConstituentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %4 = OpTypeInt 32 0 @@ -1368,7 +1369,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorConstituentTypeBad) { // Invalid: Constituent is not a constant TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorConstituentNotConstantBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeInt 32 0 @@ -1384,7 +1385,7 @@ TEST_F(ValidateIdWithMessage, // Invalid: Vector contains a mix of Undef-int and Float. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorConstituentUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %4 = OpTypeInt 32 0 @@ -1401,7 +1402,7 @@ TEST_F(ValidateIdWithMessage, // Invalid: Vector expects 3 components, but 4 specified. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorNumComponentsBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 3 %3 = OpConstant %1 3.14 @@ -1417,7 +1418,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorNumComponentsBad) { // Valid: 4x4 matrix of floats TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeMatrix %2 4 @@ -1434,7 +1435,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixGood) { // Valid: Matrix in which one column is Undef TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeMatrix %2 4 @@ -1451,7 +1452,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixUndefGood) { // Invalid: Matrix in which the sizes of column vectors are not equal. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixConstituentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeVector %1 3 @@ -1473,7 +1474,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixConstituentTypeBad) { // Invalid: Matrix type expects 4 columns but only 3 specified. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixNumColsBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeMatrix %2 4 @@ -1494,7 +1495,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixNumColsBad) { // Invalid: Composite contains a non-const/undef component TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixConstituentNotConstBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpConstant %1 0.0 %3 = OpTypeVector %1 4 @@ -1510,7 +1511,7 @@ TEST_F(ValidateIdWithMessage, // Invalid: Composite contains a column that is *not* a vector (it's an array) TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixColTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeInt 32 0 %3 = OpSpecConstant %2 4 @@ -1532,7 +1533,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixColTypeBad) { // Invalid: Matrix with an Undef column of the wrong size. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixConstituentUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeFloat 32 %2 = OpTypeVector %1 4 %3 = OpTypeVector %1 3 @@ -1554,7 +1555,7 @@ TEST_F(ValidateIdWithMessage, // Invalid: Matrix in which some columns are Int and some are Float. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixColumnTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeFloat 32 %3 = OpTypeVector %1 2 @@ -1575,7 +1576,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixColumnTypeBad) { // Valid: Array of integers TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpSpecConstant %1 4 %5 = OpConstant %1 5 @@ -1589,9 +1590,9 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayGood) { // Invalid: Expecting an array of 4 components, but 3 specified. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayNumComponentsBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 -%2 = OpSpecConstant %1 4 +%2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 %4 = OpSpecConstantComposite %3 %2 %2 %2)"; CompileSuccessfully(spirv.c_str()); @@ -1603,7 +1604,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayNumComponentsBad) { // Valid: Array of Integers and Undef-int TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayWithUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpSpecConstant %1 4 %9 = OpUndef %1 @@ -1615,7 +1616,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayWithUndefGood) { // Invalid: Array uses a type as operand. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstConstituentBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1629,7 +1630,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstConstituentBad) { // Invalid: Array has a mix of Int and Float components. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstituentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1647,7 +1648,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstituentTypeBad) { // Invalid: Array has a mix of Int and Undef-float. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstituentUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpSpecConstant %1 4 %3 = OpTypeArray %1 %2 @@ -1664,7 +1665,7 @@ TEST_F(ValidateIdWithMessage, // Valid: Struct of {Int32,Int32,Int64}. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1678,7 +1679,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructGood) { // Invalid: missing one int32 struct member. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMissingComponentBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %3 = OpTypeStruct %1 %1 %1 %4 = OpConstant %1 42 @@ -1694,7 +1695,7 @@ TEST_F(ValidateIdWithMessage, // Valid: Struct uses Undef-int64. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructUndefGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1707,7 +1708,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructUndefGood) { // Invalid: Composite contains non-const/undef component. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructNonConstBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1724,7 +1725,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructNonConstBad) { // Invalid: Struct component type does not match expected specialization type. // Second component was expected to be Int32, but got Int64. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMemberTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1741,7 +1742,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMemberTypeBad) { // Invalid: Undef-int64 used when Int32 was expected. TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMemberUndefTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeInt 64 0 %3 = OpTypeStruct %1 %1 %2 @@ -1759,7 +1760,7 @@ TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMemberUndefTypeBad) { // TODO: OpSpecConstantOp TEST_F(ValidateIdWithMessage, OpVariableGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypePointer Input %1 %3 = OpVariable %2 Input)"; @@ -1767,7 +1768,7 @@ TEST_F(ValidateIdWithMessage, OpVariableGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpVariableInitializerConstantGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypePointer Input %1 %3 = OpConstant %1 42 @@ -1776,7 +1777,7 @@ TEST_F(ValidateIdWithMessage, OpVariableInitializerConstantGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpVariableInitializerGlobalVariableGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypePointer Uniform %1 %3 = OpVariable %2 Uniform @@ -1787,7 +1788,7 @@ TEST_F(ValidateIdWithMessage, OpVariableInitializerGlobalVariableGood) { } // TODO: Positive test OpVariable with OpConstantNull of OpTypePointer TEST_F(ValidateIdWithMessage, OpVariableResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpVariable %1 Input)"; CompileSuccessfully(spirv.c_str()); @@ -1797,7 +1798,7 @@ TEST_F(ValidateIdWithMessage, OpVariableResultTypeBad) { HasSubstr("OpVariable Result Type '1' is not a pointer type.")); } TEST_F(ValidateIdWithMessage, OpVariableInitializerIsTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypePointer Input %1 %3 = OpVariable %2 Input %2)"; @@ -1809,7 +1810,7 @@ TEST_F(ValidateIdWithMessage, OpVariableInitializerIsTypeBad) { } TEST_F(ValidateIdWithMessage, OpVariableInitializerIsFunctionVarBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ptrint = OpTypePointer Function %int %ptrptrint = OpTypePointer Function %ptrint @@ -1830,7 +1831,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpVariableInitializerIsModuleVarGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ptrint = OpTypePointer Uniform %int %mvar = OpVariable %ptrint Uniform @@ -1848,7 +1849,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpLoadGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2022,7 +2023,7 @@ TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpFunctionCallGood) { } TEST_F(ValidateIdWithMessage, OpLoadResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2042,7 +2043,7 @@ TEST_F(ValidateIdWithMessage, OpLoadResultTypeBad) { } TEST_F(ValidateIdWithMessage, OpLoadPointerBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2062,7 +2063,7 @@ TEST_F(ValidateIdWithMessage, OpLoadPointerBad) { // Disabled as bitcasting type to object is now not valid. TEST_F(ValidateIdWithMessage, DISABLED_OpLoadLogicalPointerBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFloat 32 @@ -2087,7 +2088,7 @@ TEST_F(ValidateIdWithMessage, DISABLED_OpLoadLogicalPointerBad) { } TEST_F(ValidateIdWithMessage, OpStoreGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer Uniform %2 @@ -2103,7 +2104,7 @@ TEST_F(ValidateIdWithMessage, OpStoreGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpStorePointerBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2123,7 +2124,7 @@ TEST_F(ValidateIdWithMessage, OpStorePointerBad) { // Disabled as bitcasting type to object is now not valid. TEST_F(ValidateIdWithMessage, DISABLED_OpStoreLogicalPointerBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFloat 32 @@ -2182,7 +2183,7 @@ TEST_F(ValidateIdWithMessage, OpStoreVarPtrGood) { } TEST_F(ValidateIdWithMessage, OpStoreObjectGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer Uniform %2 @@ -2191,16 +2192,17 @@ TEST_F(ValidateIdWithMessage, OpStoreObjectGood) { %6 = OpVariable %3 UniformConstant %7 = OpFunction %1 None %4 %8 = OpLabel - OpStore %6 %7 +%9 = OpUndef %1 + OpStore %6 %9 OpReturn OpFunctionEnd)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpStore Object '7's type is void.")); + HasSubstr("OpStore Object '9's type is void.")); } TEST_F(ValidateIdWithMessage, OpStoreTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %9 = OpTypeFloat 32 @@ -2227,7 +2229,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeBad) { // relaxes the rules for them as well. Also need test to check for layout // decorations specific to those types. TEST_F(ValidateIdWithMessage, OpStoreTypeBadStruct) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Offset 0 OpMemberDecorate %1 1 Offset 4 OpMemberDecorate %2 0 Offset 0 @@ -2256,7 +2258,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeBadStruct) { // Same code as the last test. The difference is that we relax the rule. // Because the structs %3 and %5 are defined the same way. TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedStruct) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Offset 0 OpMemberDecorate %1 1 Offset 4 OpMemberDecorate %2 0 Offset 0 @@ -2283,7 +2285,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedStruct) { // Same code as the last test excect for an extra decoration on one of the // members. With the relaxed rules, the code is still valid. TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedStructWithExtraDecoration) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Offset 0 OpMemberDecorate %1 1 Offset 4 OpMemberDecorate %1 0 RelaxedPrecision @@ -2311,7 +2313,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedStructWithExtraDecoration) { // This test check that we recursively traverse the struct to check if they are // interchangable. TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedNestedStruct) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Offset 0 OpMemberDecorate %1 1 Offset 4 OpMemberDecorate %2 0 Offset 0 @@ -2347,7 +2349,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedNestedStruct) { // This test check that the even with the relaxed rules an error is identified // if the members of the struct are in a different order. TEST_F(ValidateIdWithMessage, OpStoreTypeBadRelaxedStruct1) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Offset 0 OpMemberDecorate %1 1 Offset 4 OpMemberDecorate %2 0 Offset 0 @@ -2387,7 +2389,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeBadRelaxedStruct1) { // This test check that the even with the relaxed rules an error is identified // if the members of the struct are at different offsets. TEST_F(ValidateIdWithMessage, OpStoreTypeBadRelaxedStruct2) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpMemberDecorate %1 0 Offset 4 OpMemberDecorate %1 1 Offset 0 OpMemberDecorate %2 0 Offset 0 @@ -2425,7 +2427,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeBadRelaxedStruct2) { } TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedLogicalPointerReturnPointer) { - const string spirv = R"( + const std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -2444,7 +2446,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedLogicalPointerReturnPointer) { } TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedLogicalPointerAllocPointer) { - const string spirv = R"( + const std::string spirv = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -2467,7 +2469,7 @@ TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedLogicalPointerAllocPointer) { } TEST_F(ValidateIdWithMessage, OpStoreVoid) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer Uniform %2 @@ -2486,7 +2488,7 @@ TEST_F(ValidateIdWithMessage, OpStoreVoid) { } TEST_F(ValidateIdWithMessage, OpStoreLabel) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer Uniform %2 @@ -2506,7 +2508,7 @@ TEST_F(ValidateIdWithMessage, OpStoreLabel) { // TODO: enable when this bug is fixed: // https://cvs.khronos.org/bugzilla/show_bug.cgi?id=15404 TEST_F(ValidateIdWithMessage, DISABLED_OpStoreFunction) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 %4 = OpTypeFunction %2 @@ -2522,7 +2524,7 @@ TEST_F(ValidateIdWithMessage, DISABLED_OpStoreFunction) { } TEST_F(ValidateIdWithMessage, OpStoreBuiltin) { - string spirv = R"( + std::string spirv = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -2560,7 +2562,7 @@ TEST_F(ValidateIdWithMessage, OpStoreBuiltin) { } TEST_F(ValidateIdWithMessage, OpCopyMemoryGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2578,8 +2580,51 @@ TEST_F(ValidateIdWithMessage, OpCopyMemoryGood) { CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } + +TEST_F(ValidateIdWithMessage, OpCopyMemoryNonPointerTarget) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 %2 %3 +%5 = OpFunction %1 None %4 +%6 = OpFunctionParameter %2 +%7 = OpFunctionParameter %3 +%8 = OpLabel +OpCopyMemory %6 %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Target operand '6' is not a pointer.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemoryNonPointerSource) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 %2 %3 +%5 = OpFunction %1 None %4 +%6 = OpFunctionParameter %2 +%7 = OpFunctionParameter %3 +%8 = OpLabel +OpCopyMemory %7 %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Source operand '6' is not a pointer.")); +} + TEST_F(ValidateIdWithMessage, OpCopyMemoryBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2598,13 +2643,56 @@ TEST_F(ValidateIdWithMessage, OpCopyMemoryBad) { CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCopyMemory Target '5's type does not match " + HasSubstr("Target '5's type does not match " "Source '2's type.")); } -// TODO: OpCopyMemorySized +TEST_F(ValidateIdWithMessage, OpCopyMemoryVoidTarget) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFunction %1 %3 %4 +%6 = OpFunction %1 None %5 +%7 = OpFunctionParameter %3 +%8 = OpFunctionParameter %4 +%9 = OpLabel +OpCopyMemory %7 %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Target operand '7' cannot be a void pointer.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemoryVoidSource) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFunction %1 %3 %4 +%6 = OpFunction %1 None %5 +%7 = OpFunctionParameter %3 +%8 = OpFunctionParameter %4 +%9 = OpLabel +OpCopyMemory %8 %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Source operand '7' cannot be a void pointer.")); +} + TEST_F(ValidateIdWithMessage, OpCopyMemorySizedGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2622,7 +2710,7 @@ TEST_F(ValidateIdWithMessage, OpCopyMemorySizedGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpCopyMemorySizedTargetBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2638,10 +2726,10 @@ TEST_F(ValidateIdWithMessage, OpCopyMemorySizedTargetBad) { CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCopyMemorySized Target '9' is not a pointer.")); + HasSubstr("Target operand '9' is not a pointer.")); } TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSourceBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2651,16 +2739,16 @@ TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSourceBad) { %7 = OpFunction %1 None %6 %8 = OpLabel %9 = OpVariable %4 Function - OpCopyMemorySized %9 %6 %5 None + OpCopyMemorySized %9 %5 %5 None OpReturn OpFunctionEnd)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCopyMemorySized Source '6' is not a pointer.")); + HasSubstr("Source operand '5' is not a pointer.")); } TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2676,12 +2764,12 @@ TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeBad) { OpFunctionEnd)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpCopyMemorySized Size '6's variable type is not " - "an integer type.")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '6' must be a scalar integer type.")); } TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypePointer UniformConstant %2 @@ -2701,8 +2789,173 @@ TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeTypeBad) { EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr( - "OpCopyMemorySized Size '9's type is not an integer type.")); + HasSubstr("Size operand '9' must be a scalar integer type.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantNull) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstantNull %2 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Size operand '3' cannot be a constant zero.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantZero) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Size operand '3' cannot be a constant zero.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantZero64) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 64 0 +%3 = OpConstant %2 0 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Size operand '3' cannot be a constant zero.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantNegative) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 1 +%3 = OpConstant %2 -1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '3' cannot have the sign bit set to 1.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantNegative64) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 64 1 +%3 = OpConstant %2 -1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '3' cannot have the sign bit set to 1.")); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeUnsignedNegative) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 2147483648 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeUnsignedNegative64) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 64 0 +%3 = OpConstant %2 9223372036854775808 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } const char kDeeplyNestedStructureSetup[] = R"( @@ -2773,10 +3026,10 @@ bool AccessChainRequiresElemId(const std::string& instr) { TEST_P(AccessChainInstructionTest, AccessChainGood) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + - "%float_entry = " + instr + - R"( %_ptr_Private_float %my_matrix )" + elem + - R"(%int_0 %int_1 + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + + "%float_entry = " + instr + + R"( %_ptr_Private_float %my_matrix )" + elem + + R"(%int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -2788,9 +3041,11 @@ TEST_P(AccessChainInstructionTest, AccessChainGood) { TEST_P(AccessChainInstructionTest, AccessChainResultTypeBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %float_entry = )" + - instr + R"( %float %my_matrix )" + elem + R"(%int_0 %int_1 + instr + + R"( %float %my_matrix )" + elem + + R"(%int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -2807,10 +3062,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainBaseTypeVoidBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %float_entry = )" + - instr + " %_ptr_Private_float %void " + elem + - R"(%int_0 %int_1 + instr + " %_ptr_Private_float %void " + elem + + R"(%int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -2826,11 +3081,11 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainBaseTypeNonPtrVariableBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Private_float %_ptr_Private_float )" + elem + - R"(%int_0 %int_1 + instr + R"( %_ptr_Private_float %_ptr_Private_float )" + + elem + + R"(%int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -2847,11 +3102,10 @@ TEST_P(AccessChainInstructionTest, AccessChainResultAndBaseStorageClassDoesntMatchBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Function_float %my_matrix )" + elem + - R"(%int_0 %int_1 + instr + R"( %_ptr_Function_float %my_matrix )" + elem + + R"(%int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -2869,10 +3123,10 @@ TEST_P(AccessChainInstructionTest, AccessChainBasePtrNotPointingToCompositeBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Private_float %my_float_var )" + elem + R"(%int_0 + instr + R"( %_ptr_Private_float %my_float_var )" + elem + + R"(%int_0 OpReturn OpFunctionEnd )"; @@ -2889,10 +3143,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainNoIndexesGood) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Private_float %my_float_var )" + elem + R"( + instr + R"( %_ptr_Private_float %my_float_var )" + elem + + R"( OpReturn OpFunctionEnd )"; @@ -2905,10 +3159,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainNoIndexesBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Private_mat4x3 %my_float_var )" + elem + R"( + instr + R"( %_ptr_Private_mat4x3 %my_float_var )" + elem + + R"( OpReturn OpFunctionEnd )"; @@ -3056,10 +3310,10 @@ TEST_P(AccessChainInstructionTest, CustomizedAccessChainTooManyIndexesBad) { TEST_P(AccessChainInstructionTest, AccessChainUndefinedIndexBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = - kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( -%entry = )" + instr + - R"( %_ptr_Private_float %my_matrix )" + elem + R"(%float %int_1 + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_float %my_matrix )" + elem + + R"(%float %int_1 OpReturn OpFunctionEnd )"; @@ -3075,10 +3329,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainStructIndexNotConstantBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %f = )" + - instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + - R"(%int_0 %spec_int %int_2 + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_0 %spec_int %int_2 OpReturn OpFunctionEnd )"; @@ -3095,11 +3349,10 @@ TEST_P(AccessChainInstructionTest, AccessChainStructResultTypeDoesntMatchIndexedTypeBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Uniform_float %blockName_var )" + elem + - R"(%int_0 %int_1 %int_2 + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_0 %int_1 %int_2 OpReturn OpFunctionEnd )"; @@ -3116,11 +3369,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainStructTooManyIndexesBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Uniform_float %blockName_var )" + elem + - R"(%int_0 %int_2 %int_2 + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_0 %int_2 %int_2 OpReturn OpFunctionEnd )"; @@ -3136,11 +3388,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainStructIndexOutOfBoundBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Uniform_float %blockName_var )" + elem + - R"(%int_3 %int_2 %int_2 + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_3 %int_2 %int_2 OpReturn OpFunctionEnd )"; @@ -3163,7 +3414,7 @@ TEST_P(AccessChainInstructionTest, AccessChainIndexIntoAllTypesGood) { // 0 will select the element at the index 0 of the vector. (which is a float). const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - ostringstream spirv; + std::ostringstream spirv; spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl; spirv << "%ss = " << instr << " %_ptr_Uniform_struct_s %blockName_var " << elem << "%int_0" << std::endl; @@ -3187,11 +3438,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %runtime_arr_entry = )" + - instr + - R"( %_ptr_Uniform_float %blockName_var )" + elem + - R"(%int_2 %int_0 + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_2 %int_0 OpReturn OpFunctionEnd )"; @@ -3203,11 +3453,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %runtime_arr_entry = )" + - instr + - R"( %_ptr_Uniform_float %blockName_var )" + elem + - R"(%int_2 %int_0 %int_1 + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_2 %int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -3224,11 +3473,10 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainMatrixMoreArgsThanNeededBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Private_float %my_matrix )" + elem + - R"(%int_0 %int_1 %int_0 + instr + R"( %_ptr_Private_float %my_matrix )" + elem + + R"(%int_0 %int_1 %int_0 OpReturn OpFunctionEnd )"; @@ -3245,11 +3493,10 @@ TEST_P(AccessChainInstructionTest, AccessChainResultTypeDoesntMatchIndexedTypeBad) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( %entry = )" + - instr + - R"( %_ptr_Private_mat4x3 %my_matrix )" + elem + - R"(%int_0 %int_1 + instr + R"( %_ptr_Private_mat4x3 %my_matrix )" + elem + + R"(%int_0 %int_1 OpReturn OpFunctionEnd )"; @@ -3273,7 +3520,7 @@ INSTANTIATE_TEST_CASE_P( // TODO: OpGenericPtrMemSemantics TEST_F(ValidateIdWithMessage, OpFunctionGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %1 %2 %2 @@ -3285,7 +3532,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpFunctionResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpConstant %2 42 @@ -3298,10 +3545,10 @@ TEST_F(ValidateIdWithMessage, OpFunctionResultTypeBad) { EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("OpFunction Result Type '2' does not match the " - "Function Type '2's return type.")); + "Function Type's return type '1'.")); } TEST_F(ValidateIdWithMessage, OpReturnValueTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeInt 32 0 %2 = OpTypeFloat 32 %3 = OpConstant %2 0 @@ -3317,7 +3564,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueTypeBad) { "OpFunction's return type.")); } TEST_F(ValidateIdWithMessage, OpFunctionFunctionTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %4 = OpFunction %1 None %2 @@ -3331,8 +3578,24 @@ OpFunctionEnd)"; HasSubstr("OpFunction Function Type '2' is not a function type.")); } +TEST_F(ValidateIdWithMessage, OpFunctionUseBad) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturnValue %3 +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid use of function result id 3.")); +} + TEST_F(ValidateIdWithMessage, OpFunctionParameterGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %1 %2 @@ -3345,7 +3608,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionParameterGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpFunctionParameterMultipleGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %1 %2 %2 @@ -3359,7 +3622,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionParameterMultipleGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpFunctionParameterResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %1 %2 @@ -3377,7 +3640,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionParameterResultTypeBad) { } TEST_F(ValidateIdWithMessage, OpFunctionCallGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 %2 @@ -3399,7 +3662,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionCallGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpFunctionCallResultTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 %2 @@ -3425,7 +3688,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionCallResultTypeBad) { "match Function '2's return type.")); } TEST_F(ValidateIdWithMessage, OpFunctionCallFunctionBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 %2 @@ -3443,7 +3706,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionCallFunctionBad) { HasSubstr("OpFunctionCall Function '5' is not a function.")); } TEST_F(ValidateIdWithMessage, OpFunctionCallArgumentTypeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 %2 @@ -3475,7 +3738,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionCallArgumentTypeBad) { // Valid: OpSampledImage result is used in the same block by // OpImageSampleImplictLod TEST_F(ValidateIdWithMessage, OpSampledImageGood) { - string spirv = kGLSL450MemoryModel + sampledImageSetup + R"( + std::string spirv = kGLSL450MemoryModel + sampledImageSetup + R"( %smpld_img = OpSampledImage %sampled_image_type %image_inst %sampler_inst %si_lod = OpImageSampleImplicitLod %v4float %smpld_img %const_vec_1_1 OpReturn @@ -3487,7 +3750,7 @@ TEST_F(ValidateIdWithMessage, OpSampledImageGood) { // Invalid: OpSampledImage result is defined in one block and used in a // different block. TEST_F(ValidateIdWithMessage, OpSampledImageUsedInDifferentBlockBad) { - string spirv = kGLSL450MemoryModel + sampledImageSetup + R"( + std::string spirv = kGLSL450MemoryModel + sampledImageSetup + R"( %smpld_img = OpSampledImage %sampled_image_type %image_inst %sampler_inst OpBranch %label_2 %label_2 = OpLabel @@ -3513,7 +3776,7 @@ OpFunctionEnd)"; // // Disabled since OpSelect catches this now. TEST_F(ValidateIdWithMessage, DISABLED_OpSampledImageUsedInOpSelectBad) { - string spirv = kGLSL450MemoryModel + sampledImageSetup + R"( + std::string spirv = kGLSL450MemoryModel + sampledImageSetup + R"( %smpld_img = OpSampledImage %sampled_image_type %image_inst %sampler_inst %select_img = OpSelect %sampled_image_type %spec_true %smpld_img %smpld_img OpReturn @@ -3529,7 +3792,7 @@ OpFunctionEnd)"; // Valid: Get a float in a matrix using CompositeExtract. // Valid: Insert float into a matrix using CompositeInsert. TEST_F(ValidateIdWithMessage, CompositeExtractInsertGood) { - ostringstream spirv; + std::ostringstream spirv; spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl; spirv << "%matrix = OpLoad %mat4x3 %my_matrix" << std::endl; spirv << "%float_entry = OpCompositeExtract %float %matrix 0 1" << std::endl; @@ -3612,7 +3875,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionCallArgumentCountBar) { // TODO: OpVectorInsertDynamic TEST_F(ValidateIdWithMessage, OpVectorShuffleIntGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ivec3 = OpTypeVector %int 3 %ivec4 = OpTypeVector %int 4 @@ -3635,7 +3898,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleIntGood) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleFloatGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %float = OpTypeFloat 32 %vec2 = OpTypeVector %float 2 %vec3 = OpTypeVector %float 3 @@ -3661,7 +3924,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleFloatGood) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleScalarResultType) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %float = OpTypeFloat 32 %vec2 = OpTypeVector %float 2 %ptr_vec2 = OpTypePointer Function %vec2 @@ -3684,7 +3947,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleScalarResultType) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleComponentCount) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ivec3 = OpTypeVector %int 3 %ptr_ivec3 = OpTypePointer Function %ivec3 @@ -3709,7 +3972,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleComponentCount) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleVector1Type) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ivec2 = OpTypeVector %int 2 %ptr_int = OpTypePointer Function %int @@ -3730,7 +3993,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleVector1Type) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleVector2Type) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ivec2 = OpTypeVector %int 2 %ptr_ivec2 = OpTypePointer Function %ivec2 @@ -3752,7 +4015,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleVector2Type) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleVector1ComponentType) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ivec3 = OpTypeVector %int 3 %ptr_ivec3 = OpTypePointer Function %ivec3 @@ -3785,7 +4048,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleVector1ComponentType) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleVector2ComponentType) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %int = OpTypeInt 32 0 %ivec3 = OpTypeVector %int 3 %ptr_ivec3 = OpTypePointer Function %ivec3 @@ -3818,7 +4081,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleVector2ComponentType) { } TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %float = OpTypeFloat 32 %vec2 = OpTypeVector %float 2 %vec3 = OpTypeVector %float 3 @@ -3836,13 +4099,16 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) { %var2 = OpVariable %ptr_vec3 Function %2 %6 = OpLoad %vec2 %var %7 = OpLoad %vec3 %var2 -%8 = OpVectorShuffle %vec4 %6 %7 0 5 2 6 +%8 = OpVectorShuffle %vec4 %6 %7 0 8 2 6 OpReturnValue %8 OpFunctionEnd)"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Component literal value 5 is greater than 4.")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Component index 8 is out of bounds for combined (Vector1 + Vector2) " + "size of 5.")); } // TODO: OpCompositeConstruct @@ -3930,7 +4196,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) { // TODO: OpBranch TEST_F(ValidateIdWithMessage, OpPhiNotAType) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -3946,12 +4212,11 @@ OpFunctionEnd CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpPhi's type 3 is not a type instruction.")); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 3 is not a type id")); } TEST_F(ValidateIdWithMessage, OpPhiSamePredecessor) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -3970,7 +4235,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpPhiOddArgumentNumber) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -3992,7 +4257,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpPhiTooFewPredecessors) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -4014,7 +4279,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpPhiTooManyPredecessors) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -4038,7 +4303,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpPhiMismatchedTypes) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -4064,7 +4329,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpPhiPredecessorNotABlock) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -4090,7 +4355,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpPhiNotAPredecessor) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( %2 = OpTypeBool %3 = OpConstantTrue %2 %4 = OpTypeVoid @@ -4116,7 +4381,7 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, OpBranchConditionalGood) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( %branch_cond = OpINotEqual %bool %i0 %i1 OpSelectionMerge %end None OpBranchConditional %branch_cond %target_t %target_f @@ -4127,7 +4392,7 @@ TEST_F(ValidateIdWithMessage, OpBranchConditionalGood) { } TEST_F(ValidateIdWithMessage, OpBranchConditionalWithWeightsGood) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( %branch_cond = OpINotEqual %bool %i0 %i1 OpSelectionMerge %end None OpBranchConditional %branch_cond %target_t %target_f 1 1 @@ -4138,7 +4403,7 @@ TEST_F(ValidateIdWithMessage, OpBranchConditionalWithWeightsGood) { } TEST_F(ValidateIdWithMessage, OpBranchConditional_CondIsScalarInt) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( OpSelectionMerge %end None OpBranchConditional %i0 %target_t %target_f )" + BranchConditionalTail; @@ -4152,45 +4417,33 @@ TEST_F(ValidateIdWithMessage, OpBranchConditional_CondIsScalarInt) { } TEST_F(ValidateIdWithMessage, OpBranchConditional_TrueTargetIsNotLabel) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( OpSelectionMerge %end None - OpBranchConditional %i0 %i0 %target_f + OpBranchConditional %true %i0 %target_f )" + BranchConditionalTail; CompileSuccessfully(spirv.c_str()); - // EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - // EXPECT_THAT( - // getDiagnosticString(), - // HasSubstr("The 'True Label' operand for OpBranchConditional must be the - // ID of an OpLabel instruction")); - - // xxxnsubtil: this is actually caught by the ID validation instead - EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("are referenced but not defined in function")); + HasSubstr("The 'True Label' operand for OpBranchConditional must " + "be the ID of an OpLabel instruction")); } TEST_F(ValidateIdWithMessage, OpBranchConditional_FalseTargetIsNotLabel) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( OpSelectionMerge %end None - OpBranchConditional %i0 %target_t %i0 + OpBranchConditional %true %target_t %i0 )" + BranchConditionalTail; CompileSuccessfully(spirv.c_str()); - // EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - // EXPECT_THAT( - // getDiagnosticString(), - // HasSubstr("The 'False Label' operand for OpBranchConditional must be - // the ID of an OpLabel instruction")); - - // xxxnsubtil: this is actually caught by the ID validation - EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("are referenced but not defined in function")); + HasSubstr("The 'False Label' operand for OpBranchConditional " + "must be the ID of an OpLabel instruction")); } TEST_F(ValidateIdWithMessage, OpBranchConditional_NotEnoughWeights) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( %branch_cond = OpINotEqual %bool %i0 %i1 OpSelectionMerge %end None OpBranchConditional %branch_cond %target_t %target_f 1 @@ -4204,7 +4457,7 @@ TEST_F(ValidateIdWithMessage, OpBranchConditional_NotEnoughWeights) { } TEST_F(ValidateIdWithMessage, OpBranchConditional_TooManyWeights) { - string spirv = BranchConditionalSetup + R"( + std::string spirv = BranchConditionalSetup + R"( %branch_cond = OpINotEqual %bool %i0 %i1 OpSelectionMerge %end None OpBranchConditional %branch_cond %target_t %target_f 1 2 3 @@ -4220,7 +4473,7 @@ TEST_F(ValidateIdWithMessage, OpBranchConditional_TooManyWeights) { // TODO: OpSwitch TEST_F(ValidateIdWithMessage, OpReturnValueConstantGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 @@ -4234,7 +4487,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueConstantGood) { } TEST_F(ValidateIdWithMessage, OpReturnValueVariableGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 ;10 %3 = OpTypeFunction %2 @@ -4251,7 +4504,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueVariableGood) { } TEST_F(ValidateIdWithMessage, OpReturnValueExpressionGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 @@ -4266,7 +4519,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueExpressionGood) { } TEST_F(ValidateIdWithMessage, OpReturnValueIsType) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 @@ -4282,7 +4535,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueIsType) { } TEST_F(ValidateIdWithMessage, OpReturnValueIsLabel) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 @@ -4298,7 +4551,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueIsLabel) { } TEST_F(ValidateIdWithMessage, OpReturnValueIsVoid) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %1 @@ -4316,7 +4569,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueIsVoid) { TEST_F(ValidateIdWithMessage, OpReturnValueIsVariableInPhysical) { // It's valid to return a pointer in a physical addressing model. - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kOpCapabilitySetup + R"( OpMemoryModel Physical32 OpenCL %1 = OpTypeVoid %2 = OpTypeInt 32 0 @@ -4333,7 +4586,7 @@ TEST_F(ValidateIdWithMessage, OpReturnValueIsVariableInPhysical) { TEST_F(ValidateIdWithMessage, OpReturnValueIsVariableInLogical) { // It's invalid to return a pointer in a physical addressing model. - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kOpCapabilitySetup + R"( OpMemoryModel Logical GLSL450 %1 = OpTypeVoid %2 = OpTypeInt 32 0 @@ -4383,7 +4636,7 @@ TEST_F(ValidateIdWithMessage, DISABLED_OpReturnValueVarPtrBad) { // TODO: enable when this bug is fixed: // https://cvs.khronos.org/bugzilla/show_bug.cgi?id=15404 TEST_F(ValidateIdWithMessage, DISABLED_OpReturnValueIsFunction) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeInt 32 0 %3 = OpTypeFunction %2 @@ -4396,7 +4649,7 @@ TEST_F(ValidateIdWithMessage, DISABLED_OpReturnValueIsFunction) { } TEST_F(ValidateIdWithMessage, UndefinedTypeId) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %s = OpTypeStruct %i32 )"; CompileSuccessfully(spirv.c_str()); @@ -4407,7 +4660,7 @@ TEST_F(ValidateIdWithMessage, UndefinedTypeId) { } TEST_F(ValidateIdWithMessage, UndefinedIdScope) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %u32 = OpTypeInt 32 0 %memsem = OpConstant %u32 0 %void = OpTypeVoid @@ -4424,7 +4677,7 @@ TEST_F(ValidateIdWithMessage, UndefinedIdScope) { } TEST_F(ValidateIdWithMessage, UndefinedIdMemSem) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %u32 = OpTypeInt 32 0 %scope = OpConstant %u32 0 %void = OpTypeVoid @@ -4442,7 +4695,7 @@ TEST_F(ValidateIdWithMessage, UndefinedIdMemSem) { TEST_F(ValidateIdWithMessage, KernelOpEntryPointAndOpInBoundsPtrAccessChainGood) { - string spirv = kOpenCLMemoryModel32 + R"( + std::string spirv = kOpenCLMemoryModel32 + R"( OpEntryPoint Kernel %2 "simple_kernel" OpSource OpenCL_C 200000 OpDecorate %3 BuiltIn GlobalInvocationId @@ -4474,7 +4727,7 @@ TEST_F(ValidateIdWithMessage, } TEST_F(ValidateIdWithMessage, OpPtrAccessChainGood) { - string spirv = kOpenCLMemoryModel64 + R"( + std::string spirv = kOpenCLMemoryModel64 + R"( OpEntryPoint Kernel %2 "another_kernel" OpSource OpenCL_C 200000 OpDecorate %3 BuiltIn GlobalInvocationId @@ -4509,7 +4762,7 @@ TEST_F(ValidateIdWithMessage, OpPtrAccessChainGood) { } TEST_F(ValidateIdWithMessage, OpLoadBitcastPointerGood) { - string spirv = kOpenCLMemoryModel64 + R"( + std::string spirv = kOpenCLMemoryModel64 + R"( %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeFloat 32 @@ -4527,7 +4780,7 @@ TEST_F(ValidateIdWithMessage, OpLoadBitcastPointerGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpLoadBitcastNonPointerBad) { - string spirv = kOpenCLMemoryModel64 + R"( + std::string spirv = kOpenCLMemoryModel64 + R"( %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeFloat 32 @@ -4548,7 +4801,7 @@ TEST_F(ValidateIdWithMessage, OpLoadBitcastNonPointerBad) { HasSubstr("OpLoad type for pointer '11' is not a pointer type.")); } TEST_F(ValidateIdWithMessage, OpStoreBitcastPointerGood) { - string spirv = kOpenCLMemoryModel64 + R"( + std::string spirv = kOpenCLMemoryModel64 + R"( %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeFloat 32 @@ -4567,7 +4820,7 @@ TEST_F(ValidateIdWithMessage, OpStoreBitcastPointerGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateIdWithMessage, OpStoreBitcastNonPointerBad) { - string spirv = kOpenCLMemoryModel64 + R"( + std::string spirv = kOpenCLMemoryModel64 + R"( %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeFloat 32 @@ -4591,7 +4844,7 @@ TEST_F(ValidateIdWithMessage, OpStoreBitcastNonPointerBad) { // Result resulting from an instruction within a function may not be used // outside that function. TEST_F(ValidateIdWithMessage, ResultIdUsedOutsideOfFunctionBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( %1 = OpTypeVoid %2 = OpTypeFunction %1 %3 = OpTypeInt 32 0 @@ -4616,27 +4869,26 @@ OpFunctionEnd } TEST_F(ValidateIdWithMessage, SpecIdTargetNotSpecializationConstant) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %1 SpecId 200 %void = OpTypeVoid %2 = OpTypeFunction %void %int = OpTypeInt 32 0 %1 = OpConstant %int 3 -%main = OpFunction %1 None %2 +%main = OpFunction %void None %2 %4 = OpLabel OpReturnValue %1 OpFunctionEnd )"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("OpDecorate SpectId decoration target '1' is not a " - "scalar specialization constant.")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorate SpecId decoration target '1' is not a " + "scalar specialization constant.")); } TEST_F(ValidateIdWithMessage, SpecIdTargetOpSpecConstantOpBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %1 SpecId 200 %void = OpTypeVoid %2 = OpTypeFunction %void @@ -4644,42 +4896,40 @@ OpDecorate %1 SpecId 200 %3 = OpConstant %int 1 %4 = OpConstant %int 2 %1 = OpSpecConstantOp %int IAdd %3 %4 -%main = OpFunction %1 None %2 +%main = OpFunction %void None %2 %6 = OpLabel OpReturnValue %3 OpFunctionEnd )"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("OpDecorate SpectId decoration target '1' is not a " - "scalar specialization constant.")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorate SpecId decoration target '1' is not a " + "scalar specialization constant.")); } TEST_F(ValidateIdWithMessage, SpecIdTargetOpSpecConstantCompositeBad) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %1 SpecId 200 %void = OpTypeVoid %2 = OpTypeFunction %void %int = OpTypeInt 32 0 %3 = OpConstant %int 1 %1 = OpSpecConstantComposite %int -%main = OpFunction %1 None %2 +%main = OpFunction %void None %2 %4 = OpLabel OpReturnValue %3 OpFunctionEnd )"; CompileSuccessfully(spirv.c_str()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("OpDecorate SpectId decoration target '1' is not a " - "scalar specialization constant.")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorate SpecId decoration target '1' is not a " + "scalar specialization constant.")); } TEST_F(ValidateIdWithMessage, SpecIdTargetGood) { - string spirv = kGLSL450MemoryModel + R"( + std::string spirv = kGLSL450MemoryModel + R"( OpDecorate %3 SpecId 200 OpDecorate %4 SpecId 201 OpDecorate %5 SpecId 202 @@ -4699,6 +4949,79 @@ OpFunctionEnd EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); } +TEST_F(ValidateIdWithMessage, CorrectErrorForShuffle) { + std::string spirv = kGLSL450MemoryModel + R"( + %uint = OpTypeInt 32 0 + %float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%v2float = OpTypeVector %float 2 + %void = OpTypeVoid + %548 = OpTypeFunction %void + %CS = OpFunction %void None %548 + %550 = OpLabel + %6275 = OpUndef %v2float + %6280 = OpUndef %v2float + %6282 = OpVectorShuffle %v4float %6275 %6280 0 1 4 5 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Component index 4 is out of bounds for combined (Vector1 + Vector2) " + "size of 4.")); + EXPECT_EQ(23, getErrorPosition().index); +} + +TEST_F(ValidateIdWithMessage, VoidStructMember) { + const std::string spirv = kGLSL450MemoryModel + R"( +%void = OpTypeVoid +%struct = OpTypeStruct %void +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Structures cannot contain a void type.")); +} + +TEST_F(ValidateIdWithMessage, TypeFunctionBadUse) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypePointer Function %2 +%4 = OpFunction %1 None %2 +%5 = OpLabel + OpReturn + OpFunctionEnd)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid use of function type result id 2.")); +} + +TEST_F(ValidateIdWithMessage, BadTypeId) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeFunction %1 + %3 = OpTypeFloat 32 + %4 = OpConstant %3 0 + %5 = OpFunction %1 None %2 + %6 = OpLabel + %7 = OpUndef %4 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 4 is not a type id")); +} + // TODO: OpLifetimeStart // TODO: OpLifetimeStop // TODO: OpAtomicInit @@ -4763,4 +5086,6 @@ OpFunctionEnd // TODO: OpGroupCommitReadPipe // TODO: OpGroupCommitWritePipe -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_image_test.cpp b/3rdparty/spirv-tools/test/val/val_image_test.cpp index bad2a4349..03f3eeb99 100644 --- a/3rdparty/spirv-tools/test/val/val_image_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_image_test.cpp @@ -18,9 +18,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -337,8 +339,8 @@ OpFunctionEnd)"; return ss.str(); } -std::string GetShaderHeader( - const std::string& capabilities_and_extensions = "") { +std::string GetShaderHeader(const std::string& capabilities_and_extensions = "", + bool include_entry_point = true) { std::ostringstream ss; ss << R"( OpCapability Shader @@ -346,10 +348,18 @@ OpCapability Int64 )"; ss << capabilities_and_extensions; + if (!include_entry_point) { + ss << "OpCapability Linkage"; + } ss << R"( OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %main "main" +)"; + + if (include_entry_point) { + ss << "OpEntryPoint Fragment %main \"main\""; + } + ss << R"( %void = OpTypeVoid %func = OpTypeFunction %void %bool = OpTypeBool @@ -363,14 +373,14 @@ OpEntryPoint Fragment %main "main" } TEST_F(ValidateImage, TypeImageWrongSampledType) { - const std::string code = GetShaderHeader() + R"( + const std::string code = GetShaderHeader("", false) + R"( %img_type = OpTypeImage %bool 2D 0 0 0 1 Unknown )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: expected Sampled Type to be either void or " + HasSubstr("Expected Sampled Type to be either void or " "numerical scalar " "type")); } @@ -378,107 +388,118 @@ TEST_F(ValidateImage, TypeImageWrongSampledType) { TEST_F(ValidateImage, TypeImageVoidSampledTypeVulkan) { const std::string code = GetShaderHeader() + R"( %img_type = OpTypeImage %void 2D 0 0 0 1 Unknown +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +OpReturn +OpFunctionEnd )"; const spv_target_env env = SPV_ENV_VULKAN_1_0; CompileSuccessfully(code, env); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: expected Sampled Type to be a 32-bit int " + HasSubstr("Expected Sampled Type to be a 32-bit int " "or float scalar type for Vulkan environment")); } TEST_F(ValidateImage, TypeImageU64SampledTypeVulkan) { const std::string code = GetShaderHeader() + R"( %img_type = OpTypeImage %u64 2D 0 0 0 1 Unknown +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +OpReturn +OpFunctionEnd )"; const spv_target_env env = SPV_ENV_VULKAN_1_0; CompileSuccessfully(code, env); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: expected Sampled Type to be a 32-bit int " + HasSubstr("Expected Sampled Type to be a 32-bit int " "or float scalar type for Vulkan environment")); } TEST_F(ValidateImage, TypeImageWrongDepth) { - const std::string code = GetShaderHeader() + R"( + const std::string code = GetShaderHeader("", false) + R"( %img_type = OpTypeImage %f32 2D 3 0 0 1 Unknown )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: invalid Depth 3 (must be 0, 1 or 2)")); + HasSubstr("Invalid Depth 3 (must be 0, 1 or 2)")); } TEST_F(ValidateImage, TypeImageWrongArrayed) { - const std::string code = GetShaderHeader() + R"( + const std::string code = GetShaderHeader("", false) + R"( %img_type = OpTypeImage %f32 2D 0 2 0 1 Unknown )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: invalid Arrayed 2 (must be 0 or 1)")); + HasSubstr("Invalid Arrayed 2 (must be 0 or 1)")); } TEST_F(ValidateImage, TypeImageWrongMS) { - const std::string code = GetShaderHeader() + R"( + const std::string code = GetShaderHeader("", false) + R"( %img_type = OpTypeImage %f32 2D 0 0 2 1 Unknown )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: invalid MS 2 (must be 0 or 1)")); + HasSubstr("Invalid MS 2 (must be 0 or 1)")); } TEST_F(ValidateImage, TypeImageWrongSampled) { - const std::string code = GetShaderHeader() + R"( + const std::string code = GetShaderHeader("", false) + R"( %img_type = OpTypeImage %f32 2D 0 0 0 3 Unknown )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: invalid Sampled 3 (must be 0, 1 or 2)")); + HasSubstr("Invalid Sampled 3 (must be 0, 1 or 2)")); } TEST_F(ValidateImage, TypeImageWrongSampledForSubpassData) { - const std::string code = GetShaderHeader("OpCapability InputAttachment\n") + - R"( + const std::string code = + GetShaderHeader("OpCapability InputAttachment\n", false) + + R"( %img_type = OpTypeImage %f32 SubpassData 0 0 0 1 Unknown )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: Dim SubpassData requires Sampled to be 2")); + HasSubstr("Dim SubpassData requires Sampled to be 2")); } TEST_F(ValidateImage, TypeImageWrongFormatForSubpassData) { - const std::string code = GetShaderHeader("OpCapability InputAttachment\n") + - R"( + const std::string code = + GetShaderHeader("OpCapability InputAttachment\n", false) + + R"( %img_type = OpTypeImage %f32 SubpassData 0 0 0 2 Rgba32f )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("TypeImage: Dim SubpassData requires format Unknown")); + HasSubstr("Dim SubpassData requires format Unknown")); } TEST_F(ValidateImage, TypeSampledImageNotImage) { - const std::string code = GetShaderHeader() + R"( + const std::string code = GetShaderHeader("", false) + R"( %simg_type = OpTypeSampledImage %f32 )"; CompileSuccessfully(code.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("TypeSampledImage: expected Image to be of type OpTypeImage")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, SampledImageSuccess) { @@ -513,9 +534,8 @@ TEST_F(ValidateImage, SampledImageWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Result Type to be OpTypeSampledImage: SampledImage")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeSampledImage")); } TEST_F(ValidateImage, SampledImageNotImage) { @@ -528,9 +548,8 @@ TEST_F(ValidateImage, SampledImageNotImage) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: SampledImage")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, SampledImageImageNotForSampling) { @@ -542,10 +561,8 @@ TEST_F(ValidateImage, SampledImageImageNotForSampling) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected Image 'Sampled' parameter to be 0 or 1: SampledImage")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 0 or 1")); } TEST_F(ValidateImage, SampledImageVulkanUnknownSampled) { @@ -559,8 +576,8 @@ TEST_F(ValidateImage, SampledImageVulkanUnknownSampled) { CompileSuccessfully(GenerateShaderCode(body, "", "Fragment", env), env); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Image 'Sampled' parameter to be 1 for Vulkan " - "environment: SampledImage")); + HasSubstr("Expected Image 'Sampled' parameter to " + "be 1 for Vulkan environment.")); } TEST_F(ValidateImage, SampledImageNotSampler) { @@ -572,9 +589,8 @@ TEST_F(ValidateImage, SampledImageNotSampler) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Sampler to be of type OpTypeSampler: SampledImage")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sampler to be of type OpTypeSampler")); } TEST_F(ValidateImage, SampleImplicitLodSuccess) { @@ -605,8 +621,7 @@ TEST_F(ValidateImage, SampleImplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float vector type: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Result Type to be int or float vector type")); } TEST_F(ValidateImage, SampleImplicitLodWrongNumComponentsResultType) { @@ -620,8 +635,7 @@ TEST_F(ValidateImage, SampleImplicitLodWrongNumComponentsResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, SampleImplicitLodNotSampledImage) { @@ -634,8 +648,7 @@ TEST_F(ValidateImage, SampleImplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleImplicitLodWrongSampledType) { @@ -650,8 +663,7 @@ TEST_F(ValidateImage, SampleImplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageSampleImplicitLod")); + "Result Type components")); } TEST_F(ValidateImage, SampleImplicitLodVoidSampledType) { @@ -677,8 +689,7 @@ TEST_F(ValidateImage, SampleImplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleImplicitLodCoordinateSizeTooSmall) { @@ -693,8 +704,7 @@ TEST_F(ValidateImage, SampleImplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageSampleImplicitLod")); + "but given only 1")); } TEST_F(ValidateImage, SampleExplicitLodSuccessShader) { @@ -752,8 +762,7 @@ TEST_F(ValidateImage, SampleExplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float vector type: " - "ImageSampleExplicitLod")); + HasSubstr("Expected Result Type to be int or float vector type")); } TEST_F(ValidateImage, SampleExplicitLodWrongNumComponentsResultType) { @@ -767,8 +776,7 @@ TEST_F(ValidateImage, SampleExplicitLodWrongNumComponentsResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: " - "ImageSampleExplicitLod")); + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, SampleExplicitLodNotSampledImage) { @@ -781,8 +789,7 @@ TEST_F(ValidateImage, SampleExplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleExplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleExplicitLodWrongSampledType) { @@ -797,8 +804,7 @@ TEST_F(ValidateImage, SampleExplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageSampleExplicitLod")); + "Result Type components")); } TEST_F(ValidateImage, SampleExplicitLodVoidSampledType) { @@ -824,8 +830,7 @@ TEST_F(ValidateImage, SampleExplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleExplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleExplicitLodCoordinateSizeTooSmall) { @@ -840,8 +845,7 @@ TEST_F(ValidateImage, SampleExplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageSampleExplicitLod")); + "but given only 1")); } TEST_F(ValidateImage, SampleExplicitLodBias) { @@ -856,8 +860,8 @@ TEST_F(ValidateImage, SampleExplicitLodBias) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Image Operand Bias can only be used with ImplicitLod opcodes: " - "ImageSampleExplicitLod")); + HasSubstr( + "Image Operand Bias can only be used with ImplicitLod opcodes")); } TEST_F(ValidateImage, LodAndGrad) { @@ -873,8 +877,7 @@ TEST_F(ValidateImage, LodAndGrad) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Image Operand bits Lod and Grad cannot be set at the same time: " - "ImageSampleExplicitLod")); + "Image Operand bits Lod and Grad cannot be set at the same time")); } TEST_F(ValidateImage, ImplicitLodWithLod) { @@ -890,7 +893,7 @@ TEST_F(ValidateImage, ImplicitLodWithLod) { EXPECT_THAT( getDiagnosticString(), HasSubstr("Image Operand Lod can only be used with ExplicitLod opcodes " - "and OpImageFetch: ImageSampleImplicitLod")); + "and OpImageFetch")); } TEST_F(ValidateImage, LodWrongType) { @@ -904,7 +907,7 @@ TEST_F(ValidateImage, LodWrongType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image Operand Lod to be float scalar when " - "used with ExplicitLod: ImageSampleExplicitLod")); + "used with ExplicitLod")); } TEST_F(ValidateImage, LodWrongDim) { @@ -918,8 +921,7 @@ TEST_F(ValidateImage, LodWrongDim) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Image Operand Lod requires 'Dim' parameter to be 1D, " - "2D, 3D or Cube: " - "ImageSampleExplicitLod")); + "2D, 3D or Cube")); } TEST_F(ValidateImage, LodMultisampled) { @@ -932,8 +934,7 @@ TEST_F(ValidateImage, LodMultisampled) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image Operand Lod requires 'MS' parameter to be 0: " - "ImageSampleExplicitLod")); + HasSubstr("Image Operand Lod requires 'MS' parameter to be 0")); } TEST_F(ValidateImage, MinLodIncompatible) { @@ -949,7 +950,7 @@ TEST_F(ValidateImage, MinLodIncompatible) { getDiagnosticString(), HasSubstr( "Image Operand MinLod can only be used with ImplicitLod opcodes or " - "together with Image Operand Grad: ImageSampleExplicitLod")); + "together with Image Operand Grad")); } TEST_F(ValidateImage, ImplicitLodWithGrad) { @@ -964,8 +965,8 @@ TEST_F(ValidateImage, ImplicitLodWithGrad) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Image Operand Grad can only be used with ExplicitLod opcodes: " - "ImageSampleImplicitLod")); + HasSubstr( + "Image Operand Grad can only be used with ExplicitLod opcodes")); } TEST_F(ValidateImage, SampleImplicitLod3DArrayedMultisampledSuccess) { @@ -1008,8 +1009,7 @@ TEST_F(ValidateImage, SampleImplicitLodBiasWrongType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Image Operand Bias to be float scalar: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Image Operand Bias to be float scalar")); } TEST_F(ValidateImage, SampleImplicitLodBiasWrongDim) { @@ -1024,8 +1024,7 @@ TEST_F(ValidateImage, SampleImplicitLodBiasWrongDim) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Image Operand Bias requires 'Dim' parameter to be 1D, " - "2D, 3D or Cube: " - "ImageSampleImplicitLod")); + "2D, 3D or Cube")); } TEST_F(ValidateImage, SampleImplicitLodBiasMultisampled) { @@ -1039,8 +1038,7 @@ TEST_F(ValidateImage, SampleImplicitLodBiasMultisampled) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image Operand Bias requires 'MS' parameter to be 0: " - "ImageSampleImplicitLod")); + HasSubstr("Image Operand Bias requires 'MS' parameter to be 0")); } TEST_F(ValidateImage, SampleExplicitLodGradDxWrongType) { @@ -1055,8 +1053,7 @@ TEST_F(ValidateImage, SampleExplicitLodGradDxWrongType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected both Image Operand Grad ids to be float " - "scalars or vectors: " - "ImageSampleExplicitLod")); + "scalars or vectors")); } TEST_F(ValidateImage, SampleExplicitLodGradDyWrongType) { @@ -1071,8 +1068,7 @@ TEST_F(ValidateImage, SampleExplicitLodGradDyWrongType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected both Image Operand Grad ids to be float " - "scalars or vectors: " - "ImageSampleExplicitLod")); + "scalars or vectors")); } TEST_F(ValidateImage, SampleExplicitLodGradDxWrongSize) { @@ -1088,8 +1084,7 @@ TEST_F(ValidateImage, SampleExplicitLodGradDxWrongSize) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Expected Image Operand Grad dx to have 3 components, but given 2: " - "ImageSampleExplicitLod")); + "Expected Image Operand Grad dx to have 3 components, but given 2")); } TEST_F(ValidateImage, SampleExplicitLodGradDyWrongSize) { @@ -1105,8 +1100,7 @@ TEST_F(ValidateImage, SampleExplicitLodGradDyWrongSize) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Expected Image Operand Grad dy to have 3 components, but given 2: " - "ImageSampleExplicitLod")); + "Expected Image Operand Grad dy to have 3 components, but given 2")); } TEST_F(ValidateImage, SampleExplicitLodGradMultisampled) { @@ -1120,8 +1114,7 @@ TEST_F(ValidateImage, SampleExplicitLodGradMultisampled) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image Operand Grad requires 'MS' parameter to be 0: " - "ImageSampleExplicitLod")); + HasSubstr("Image Operand Grad requires 'MS' parameter to be 0")); } TEST_F(ValidateImage, SampleImplicitLodConstOffsetCubeDim) { @@ -1137,8 +1130,7 @@ TEST_F(ValidateImage, SampleImplicitLodConstOffsetCubeDim) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Image Operand ConstOffset cannot be used with Cube Image 'Dim': " - "ImageSampleImplicitLod")); + "Image Operand ConstOffset cannot be used with Cube Image 'Dim'")); } TEST_F(ValidateImage, SampleImplicitLodConstOffsetWrongType) { @@ -1154,8 +1146,7 @@ TEST_F(ValidateImage, SampleImplicitLodConstOffsetWrongType) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Expected Image Operand ConstOffset to be int scalar or vector: " - "ImageSampleImplicitLod")); + "Expected Image Operand ConstOffset to be int scalar or vector")); } TEST_F(ValidateImage, SampleImplicitLodConstOffsetWrongSize) { @@ -1170,8 +1161,7 @@ TEST_F(ValidateImage, SampleImplicitLodConstOffsetWrongSize) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image Operand ConstOffset to have 3 " - "components, but given 2: " - "ImageSampleImplicitLod")); + "components, but given 2")); } TEST_F(ValidateImage, SampleImplicitLodConstOffsetNotConst) { @@ -1187,8 +1177,7 @@ TEST_F(ValidateImage, SampleImplicitLodConstOffsetNotConst) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image Operand ConstOffset to be a const object: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Image Operand ConstOffset to be a const object")); } TEST_F(ValidateImage, SampleImplicitLodOffsetCubeDim) { @@ -1203,8 +1192,7 @@ TEST_F(ValidateImage, SampleImplicitLodOffsetCubeDim) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Image Operand Offset cannot be used with Cube Image 'Dim': " - "ImageSampleImplicitLod")); + HasSubstr("Image Operand Offset cannot be used with Cube Image 'Dim'")); } TEST_F(ValidateImage, SampleImplicitLodOffsetWrongType) { @@ -1219,8 +1207,7 @@ TEST_F(ValidateImage, SampleImplicitLodOffsetWrongType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image Operand Offset to be int scalar or vector: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Image Operand Offset to be int scalar or vector")); } TEST_F(ValidateImage, SampleImplicitLodOffsetWrongSize) { @@ -1236,8 +1223,7 @@ TEST_F(ValidateImage, SampleImplicitLodOffsetWrongSize) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Expected Image Operand Offset to have 3 components, but given 2: " - "ImageSampleImplicitLod")); + "Expected Image Operand Offset to have 3 components, but given 2")); } TEST_F(ValidateImage, SampleImplicitLodMoreThanOneOffset) { @@ -1252,8 +1238,7 @@ TEST_F(ValidateImage, SampleImplicitLodMoreThanOneOffset) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Image Operands Offset, ConstOffset, ConstOffsets " - "cannot be used together: " - "ImageSampleImplicitLod")); + "cannot be used together")); } TEST_F(ValidateImage, SampleImplicitLodMinLodWrongType) { @@ -1267,8 +1252,7 @@ TEST_F(ValidateImage, SampleImplicitLodMinLodWrongType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Image Operand MinLod to be float scalar: " - "ImageSampleImplicitLod")); + HasSubstr("Expected Image Operand MinLod to be float scalar")); } TEST_F(ValidateImage, SampleImplicitLodMinLodWrongDim) { @@ -1283,8 +1267,7 @@ TEST_F(ValidateImage, SampleImplicitLodMinLodWrongDim) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Image Operand MinLod requires 'Dim' parameter to be " - "1D, 2D, 3D or Cube: " - "ImageSampleImplicitLod")); + "1D, 2D, 3D or Cube")); } TEST_F(ValidateImage, SampleImplicitLodMinLodMultisampled) { @@ -1297,9 +1280,9 @@ TEST_F(ValidateImage, SampleImplicitLodMinLodMultisampled) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image Operand MinLod requires 'MS' parameter to be 0: " - "ImageSampleImplicitLod")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Operand MinLod requires 'MS' parameter to be 0")); } TEST_F(ValidateImage, SampleProjExplicitLodSuccess2D) { @@ -1342,8 +1325,7 @@ TEST_F(ValidateImage, SampleProjExplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float vector type: " - "ImageSampleProjExplicitLod")); + HasSubstr("Expected Result Type to be int or float vector type")); } TEST_F(ValidateImage, SampleProjExplicitLodWrongNumComponentsResultType) { @@ -1357,8 +1339,7 @@ TEST_F(ValidateImage, SampleProjExplicitLodWrongNumComponentsResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: " - "ImageSampleProjExplicitLod")); + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, SampleProjExplicitLodNotSampledImage) { @@ -1371,8 +1352,7 @@ TEST_F(ValidateImage, SampleProjExplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleProjExplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleProjExplicitLodWrongSampledType) { @@ -1387,8 +1367,7 @@ TEST_F(ValidateImage, SampleProjExplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageSampleProjExplicitLod")); + "Result Type components")); } TEST_F(ValidateImage, SampleProjExplicitLodVoidSampledType) { @@ -1414,8 +1393,7 @@ TEST_F(ValidateImage, SampleProjExplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleProjExplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleProjExplicitLodCoordinateSizeTooSmall) { @@ -1430,8 +1408,7 @@ TEST_F(ValidateImage, SampleProjExplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 3 components, " - "but given only 2: " - "ImageSampleProjExplicitLod")); + "but given only 2")); } TEST_F(ValidateImage, SampleProjImplicitLodSuccess) { @@ -1462,8 +1439,7 @@ TEST_F(ValidateImage, SampleProjImplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float vector type: " - "ImageSampleProjImplicitLod")); + HasSubstr("Expected Result Type to be int or float vector type")); } TEST_F(ValidateImage, SampleProjImplicitLodWrongNumComponentsResultType) { @@ -1477,8 +1453,7 @@ TEST_F(ValidateImage, SampleProjImplicitLodWrongNumComponentsResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: " - "ImageSampleProjImplicitLod")); + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, SampleProjImplicitLodNotSampledImage) { @@ -1491,8 +1466,7 @@ TEST_F(ValidateImage, SampleProjImplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleProjImplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleProjImplicitLodWrongSampledType) { @@ -1507,8 +1481,7 @@ TEST_F(ValidateImage, SampleProjImplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageSampleProjImplicitLod")); + "Result Type components")); } TEST_F(ValidateImage, SampleProjImplicitLodVoidSampledType) { @@ -1534,8 +1507,7 @@ TEST_F(ValidateImage, SampleProjImplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleProjImplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleProjImplicitLodCoordinateSizeTooSmall) { @@ -1550,8 +1522,7 @@ TEST_F(ValidateImage, SampleProjImplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 3 components, " - "but given only 2: " - "ImageSampleProjImplicitLod")); + "but given only 2")); } TEST_F(ValidateImage, SampleDrefImplicitLodSuccess) { @@ -1582,8 +1553,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float scalar type: " - "ImageSampleDrefImplicitLod")); + HasSubstr("Expected Result Type to be int or float scalar type")); } TEST_F(ValidateImage, SampleDrefImplicitLodNotSampledImage) { @@ -1596,8 +1566,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleDrefImplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleDrefImplicitLodWrongSampledType) { @@ -1612,8 +1581,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleDrefImplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleDrefImplicitLodVoidSampledType) { @@ -1628,8 +1596,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodVoidSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleDrefImplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleDrefImplicitLodWrongCoordinateType) { @@ -1643,8 +1610,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleDrefImplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleDrefImplicitLodCoordinateSizeTooSmall) { @@ -1659,8 +1625,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageSampleDrefImplicitLod")); + "but given only 1")); } TEST_F(ValidateImage, SampleDrefImplicitLodWrongDrefType) { @@ -1674,8 +1639,7 @@ TEST_F(ValidateImage, SampleDrefImplicitLodWrongDrefType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSampleDrefImplicitLod: " - "Expected Dref to be of 32-bit float type")); + HasSubstr("Expected Dref to be of 32-bit float type")); } TEST_F(ValidateImage, SampleDrefExplicitLodSuccess) { @@ -1705,8 +1669,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float scalar type: " - "ImageSampleDrefExplicitLod")); + HasSubstr("Expected Result Type to be int or float scalar type")); } TEST_F(ValidateImage, SampleDrefExplicitLodNotSampledImage) { @@ -1719,8 +1682,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleDrefExplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleDrefExplicitLodWrongSampledType) { @@ -1735,8 +1697,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleDrefExplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleDrefExplicitLodVoidSampledType) { @@ -1751,8 +1712,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodVoidSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleDrefExplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleDrefExplicitLodWrongCoordinateType) { @@ -1766,8 +1726,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleDrefExplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleDrefExplicitLodCoordinateSizeTooSmall) { @@ -1782,8 +1741,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 3 components, " - "but given only 2: " - "ImageSampleDrefExplicitLod")); + "but given only 2")); } TEST_F(ValidateImage, SampleDrefExplicitLodWrongDrefType) { @@ -1797,8 +1755,7 @@ TEST_F(ValidateImage, SampleDrefExplicitLodWrongDrefType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSampleDrefExplicitLod: " - "Expected Dref to be of 32-bit float type")); + HasSubstr("Expected Dref to be of 32-bit float type")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodSuccess) { @@ -1829,8 +1786,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float scalar type: " - "ImageSampleProjDrefImplicitLod")); + HasSubstr("Expected Result Type to be int or float scalar type")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodNotSampledImage) { @@ -1843,8 +1799,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleProjDrefImplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongSampledType) { @@ -1859,8 +1814,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleProjDrefImplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodVoidSampledType) { @@ -1875,8 +1829,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodVoidSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleProjDrefImplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongCoordinateType) { @@ -1890,8 +1843,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleProjDrefImplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodCoordinateSizeTooSmall) { @@ -1906,8 +1858,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 3 components, " - "but given only 2: " - "ImageSampleProjDrefImplicitLod")); + "but given only 2")); } TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongDrefType) { @@ -1921,8 +1872,7 @@ TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongDrefType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSampleProjDrefImplicitLod: " - "Expected Dref to be of 32-bit float type")); + HasSubstr("Expected Dref to be of 32-bit float type")); } TEST_F(ValidateImage, SampleProjDrefExplicitLodSuccess) { @@ -1952,8 +1902,7 @@ TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float scalar type: " - "ImageSampleProjDrefExplicitLod")); + HasSubstr("Expected Result Type to be int or float scalar type")); } TEST_F(ValidateImage, SampleProjDrefExplicitLodNotSampledImage) { @@ -1966,8 +1915,7 @@ TEST_F(ValidateImage, SampleProjDrefExplicitLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageSampleProjDrefExplicitLod")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongSampledType) { @@ -1982,8 +1930,7 @@ TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleProjDrefExplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleProjDrefExplicitLodVoidSampledType) { @@ -1998,8 +1945,7 @@ TEST_F(ValidateImage, SampleProjDrefExplicitLodVoidSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type: " - "ImageSampleProjDrefExplicitLod")); + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); } TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongCoordinateType) { @@ -2013,8 +1959,7 @@ TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageSampleProjDrefExplicitLod")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, SampleProjDrefExplicitLodCoordinateSizeTooSmall) { @@ -2029,8 +1974,7 @@ TEST_F(ValidateImage, SampleProjDrefExplicitLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageSampleProjDrefExplicitLod")); + "but given only 1")); } TEST_F(ValidateImage, FetchSuccess) { @@ -2052,8 +1996,7 @@ TEST_F(ValidateImage, FetchWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float vector type: " - "ImageFetch")); + HasSubstr("Expected Result Type to be int or float vector type")); } TEST_F(ValidateImage, FetchWrongNumComponentsResultType) { @@ -2064,9 +2007,8 @@ TEST_F(ValidateImage, FetchWrongNumComponentsResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: ImageFetch")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, FetchNotImage) { @@ -2079,9 +2021,8 @@ TEST_F(ValidateImage, FetchNotImage) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: ImageFetch")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, FetchNotSampled) { @@ -2092,9 +2033,8 @@ TEST_F(ValidateImage, FetchNotSampled) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image 'Sampled' parameter to be 1: ImageFetch")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 1")); } TEST_F(ValidateImage, FetchCube) { @@ -2105,8 +2045,7 @@ TEST_F(ValidateImage, FetchCube) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image 'Dim' cannot be Cube: ImageFetch")); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'Dim' cannot be Cube")); } TEST_F(ValidateImage, FetchWrongSampledType) { @@ -2119,8 +2058,7 @@ TEST_F(ValidateImage, FetchWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageFetch")); + "Result Type components")); } TEST_F(ValidateImage, FetchVoidSampledType) { @@ -2144,8 +2082,7 @@ TEST_F(ValidateImage, FetchWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be int scalar or vector: " - "ImageFetch")); + HasSubstr("Expected Coordinate to be int scalar or vector")); } TEST_F(ValidateImage, FetchCoordinateSizeTooSmall) { @@ -2158,8 +2095,7 @@ TEST_F(ValidateImage, FetchCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageFetch")); + "but given only 1")); } TEST_F(ValidateImage, FetchLodNotInt) { @@ -2199,8 +2135,7 @@ TEST_F(ValidateImage, GatherWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int or float vector type: " - "ImageGather")); + HasSubstr("Expected Result Type to be int or float vector type")); } TEST_F(ValidateImage, GatherWrongNumComponentsResultType) { @@ -2214,8 +2149,7 @@ TEST_F(ValidateImage, GatherWrongNumComponentsResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: " - "ImageGather")); + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, GatherNotSampledImage) { @@ -2228,8 +2162,7 @@ TEST_F(ValidateImage, GatherNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage: " - "ImageGather")); + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, GatherWrongSampledType) { @@ -2244,8 +2177,7 @@ TEST_F(ValidateImage, GatherWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageGather")); + "Result Type components")); } TEST_F(ValidateImage, GatherVoidSampledType) { @@ -2271,8 +2203,7 @@ TEST_F(ValidateImage, GatherWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Coordinate to be float scalar or vector: " - "ImageGather")); + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, GatherCoordinateSizeTooSmall) { @@ -2287,8 +2218,7 @@ TEST_F(ValidateImage, GatherCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 4 components, " - "but given only 1: " - "ImageGather")); + "but given only 1")); } TEST_F(ValidateImage, GatherWrongComponentType) { @@ -2302,8 +2232,7 @@ TEST_F(ValidateImage, GatherWrongComponentType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Component to be 32-bit int scalar: " - "ImageGather")); + HasSubstr("Expected Component to be 32-bit int scalar")); } TEST_F(ValidateImage, GatherComponentNot32Bit) { @@ -2317,8 +2246,7 @@ TEST_F(ValidateImage, GatherComponentNot32Bit) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Component to be 32-bit int scalar: " - "ImageGather")); + HasSubstr("Expected Component to be 32-bit int scalar")); } TEST_F(ValidateImage, GatherDimCube) { @@ -2334,8 +2262,7 @@ TEST_F(ValidateImage, GatherDimCube) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Image Operand ConstOffsets cannot be used with Cube Image 'Dim': " - "ImageGather")); + "Image Operand ConstOffsets cannot be used with Cube Image 'Dim'")); } TEST_F(ValidateImage, GatherConstOffsetsNotArray) { @@ -2350,8 +2277,8 @@ TEST_F(ValidateImage, GatherConstOffsetsNotArray) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image Operand ConstOffsets to be an array of size 4: " - "ImageGather")); + HasSubstr( + "Expected Image Operand ConstOffsets to be an array of size 4")); } TEST_F(ValidateImage, GatherConstOffsetsArrayWrongSize) { @@ -2366,8 +2293,8 @@ TEST_F(ValidateImage, GatherConstOffsetsArrayWrongSize) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image Operand ConstOffsets to be an array of size 4: " - "ImageGather")); + HasSubstr( + "Expected Image Operand ConstOffsets to be an array of size 4")); } TEST_F(ValidateImage, GatherConstOffsetsArrayNotVector) { @@ -2382,8 +2309,7 @@ TEST_F(ValidateImage, GatherConstOffsetsArrayNotVector) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image Operand ConstOffsets array componenets " - "to be int vectors " - "of size 2: ImageGather")); + "to be int vectors of size 2")); } TEST_F(ValidateImage, GatherConstOffsetsArrayVectorWrongSize) { @@ -2398,8 +2324,7 @@ TEST_F(ValidateImage, GatherConstOffsetsArrayVectorWrongSize) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image Operand ConstOffsets array componenets " - "to be int vectors " - "of size 2: ImageGather")); + "to be int vectors of size 2")); } TEST_F(ValidateImage, GatherConstOffsetsArrayNotConst) { @@ -2415,8 +2340,7 @@ TEST_F(ValidateImage, GatherConstOffsetsArrayNotConst) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image Operand ConstOffsets to be a const object: " - "ImageGather")); + HasSubstr("Expected Image Operand ConstOffsets to be a const object")); } TEST_F(ValidateImage, NotGatherWithConstOffsets) { @@ -2433,7 +2357,7 @@ TEST_F(ValidateImage, NotGatherWithConstOffsets) { getDiagnosticString(), HasSubstr( "Image Operand ConstOffsets can only be used with OpImageGather " - "and OpImageDrefGather: ImageSampleImplicitLod")); + "and OpImageDrefGather")); } TEST_F(ValidateImage, DrefGatherSuccess) { @@ -2461,8 +2385,7 @@ TEST_F(ValidateImage, DrefGatherVoidSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageDrefGather")); + "Result Type components")); } TEST_F(ValidateImage, DrefGatherWrongDrefType) { @@ -2476,8 +2399,7 @@ TEST_F(ValidateImage, DrefGatherWrongDrefType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageDrefGather: " - "Expected Dref to be of 32-bit float type")); + HasSubstr("Expected Dref to be of 32-bit float type")); } TEST_F(ValidateImage, ReadSuccess1) { @@ -2533,8 +2455,7 @@ TEST_F(ValidateImage, ReadNeedCapabilityStorageImageReadWithoutFormat) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Capability StorageImageReadWithoutFormat is required " - "to read storage " - "image: ImageRead")); + "to read storage image")); } TEST_F(ValidateImage, ReadNeedCapabilityImage1D) { @@ -2547,8 +2468,7 @@ TEST_F(ValidateImage, ReadNeedCapabilityImage1D) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr( - "Capability Image1D is required to access storage image: ImageRead")); + HasSubstr("Capability Image1D is required to access storage image")); } TEST_F(ValidateImage, ReadNeedCapabilityImageCubeArray) { @@ -2562,8 +2482,7 @@ TEST_F(ValidateImage, ReadNeedCapabilityImageCubeArray) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Capability ImageCubeArray is required to access storage image: " - "ImageRead")); + "Capability ImageCubeArray is required to access storage image")); } // TODO(atgoo@github.com) Disabled until the spec is clarified. @@ -2576,10 +2495,8 @@ TEST_F(ValidateImage, DISABLED_ReadWrongResultType) { const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected Result Type to be int or float vector type: ImageRead")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); } // TODO(atgoo@github.com) Disabled until the spec is clarified. @@ -2592,9 +2509,8 @@ TEST_F(ValidateImage, DISABLED_ReadWrongNumComponentsResultType) { const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Result Type to have 4 components: ImageRead")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); } TEST_F(ValidateImage, ReadNotImage) { @@ -2607,7 +2523,7 @@ TEST_F(ValidateImage, ReadNotImage) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: ImageRead")); + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, ReadImageSampled) { @@ -2619,9 +2535,8 @@ TEST_F(ValidateImage, ReadImageSampled) { const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image 'Sampled' parameter to be 0 or 2: ImageRead")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 0 or 2")); } TEST_F(ValidateImage, ReadWrongSampledType) { @@ -2635,8 +2550,7 @@ TEST_F(ValidateImage, ReadWrongSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type components: " - "ImageRead")); + "Result Type components")); } TEST_F(ValidateImage, ReadVoidSampledType) { @@ -2661,9 +2575,8 @@ TEST_F(ValidateImage, ReadWrongCoordinateType) { const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Coordinate to be int scalar or vector: ImageRead")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be int scalar or vector")); } TEST_F(ValidateImage, ReadCoordinateSizeTooSmall) { @@ -2677,8 +2590,7 @@ TEST_F(ValidateImage, ReadCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageRead")); + "but given only 1")); } TEST_F(ValidateImage, WriteSuccess1) { @@ -2736,7 +2648,7 @@ TEST_F(ValidateImage, WriteSubpassData) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image 'Dim' cannot be SubpassData: ImageWrite")); + HasSubstr("Image 'Dim' cannot be SubpassData")); } TEST_F(ValidateImage, WriteNeedCapabilityStorageImageWriteWithoutFormat) { @@ -2751,7 +2663,7 @@ TEST_F(ValidateImage, WriteNeedCapabilityStorageImageWriteWithoutFormat) { getDiagnosticString(), HasSubstr( "Capability StorageImageWriteWithoutFormat is required to write to " - "storage image: ImageWrite")); + "storage image")); } TEST_F(ValidateImage, WriteNeedCapabilityImage1D) { @@ -2764,7 +2676,7 @@ TEST_F(ValidateImage, WriteNeedCapabilityImage1D) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Capability Image1D is required to access storage " - "image: ImageWrite")); + "image")); } TEST_F(ValidateImage, WriteNeedCapabilityImageCubeArray) { @@ -2778,8 +2690,7 @@ TEST_F(ValidateImage, WriteNeedCapabilityImageCubeArray) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Capability ImageCubeArray is required to access storage image: " - "ImageWrite")); + "Capability ImageCubeArray is required to access storage image")); } TEST_F(ValidateImage, WriteNotImage) { @@ -2790,9 +2701,8 @@ TEST_F(ValidateImage, WriteNotImage) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: ImageWrite")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, WriteImageSampled) { @@ -2804,9 +2714,8 @@ TEST_F(ValidateImage, WriteImageSampled) { const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image 'Sampled' parameter to be 0 or 2: ImageWrite")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 0 or 2")); } TEST_F(ValidateImage, WriteWrongCoordinateType) { @@ -2818,9 +2727,8 @@ TEST_F(ValidateImage, WriteWrongCoordinateType) { const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Coordinate to be int scalar or vector: ImageWrite")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be int scalar or vector")); } TEST_F(ValidateImage, WriteCoordinateSizeTooSmall) { @@ -2834,8 +2742,7 @@ TEST_F(ValidateImage, WriteCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageWrite")); + "but given only 1")); } TEST_F(ValidateImage, WriteTexelWrongType) { @@ -2847,10 +2754,8 @@ TEST_F(ValidateImage, WriteTexelWrongType) { const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected Texel to be int or float vector or scalar: ImageWrite")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Texel to be int or float vector or scalar")); } TEST_F(ValidateImage, DISABLED_WriteTexelNotVector4) { @@ -2863,7 +2768,7 @@ TEST_F(ValidateImage, DISABLED_WriteTexelNotVector4) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Texel to have 4 components: ImageWrite")); + HasSubstr("Expected Texel to have 4 components")); } TEST_F(ValidateImage, WriteTexelWrongComponentType) { @@ -2878,8 +2783,7 @@ TEST_F(ValidateImage, WriteTexelWrongComponentType) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Expected Image 'Sampled Type' to be the same as Texel components: " - "ImageWrite")); + "Expected Image 'Sampled Type' to be the same as Texel components")); } TEST_F(ValidateImage, WriteSampleNotInteger) { @@ -2892,8 +2796,7 @@ TEST_F(ValidateImage, WriteSampleNotInteger) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Image Operand Sample to be int scalar: " - "ImageWrite")); + HasSubstr("Expected Image Operand Sample to be int scalar")); } TEST_F(ValidateImage, SampleNotMultisampled) { @@ -2907,8 +2810,7 @@ TEST_F(ValidateImage, SampleNotMultisampled) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr( - "Image Operand Sample requires non-zero 'MS' parameter: ImageWrite")); + HasSubstr("Image Operand Sample requires non-zero 'MS' parameter")); } TEST_F(ValidateImage, SampleWrongOpcode) { @@ -2924,8 +2826,7 @@ TEST_F(ValidateImage, SampleWrongOpcode) { EXPECT_THAT(getDiagnosticString(), HasSubstr("Image Operand Sample can only be used with " "OpImageFetch, OpImageRead, OpImageWrite, " - "OpImageSparseFetch and OpImageSparseRead: " - "ImageSampleExplicitLod")); + "OpImageSparseFetch and OpImageSparseRead")); } TEST_F(ValidateImage, SampleImageToImageSuccess) { @@ -2951,7 +2852,7 @@ TEST_F(ValidateImage, SampleImageToImageWrongResultType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be OpTypeImage: Image")); + HasSubstr("Expected Result Type to be OpTypeImage")); } TEST_F(ValidateImage, SampleImageToImageNotSampledImage) { @@ -2964,8 +2865,7 @@ TEST_F(ValidateImage, SampleImageToImageNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr( - "Expected Sample Image to be of type OpTypeSampleImage: Image")); + HasSubstr("Expected Sample Image to be of type OpTypeSampleImage")); } TEST_F(ValidateImage, SampleImageToImageNotTheSameImageType) { @@ -2980,7 +2880,7 @@ TEST_F(ValidateImage, SampleImageToImageNotTheSameImageType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Sample Image image type to be equal to " - "Result Type: Image")); + "Result Type")); } TEST_F(ValidateImage, QueryFormatSuccess) { @@ -3001,10 +2901,8 @@ TEST_F(ValidateImage, QueryFormatWrongResultType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected Result Type to be int scalar type: ImageQueryFormat")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar type")); } TEST_F(ValidateImage, QueryFormatNotImage) { @@ -3017,10 +2915,8 @@ TEST_F(ValidateImage, QueryFormatNotImage) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected operand to be of type OpTypeImage: ImageQueryFormat")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected operand to be of type OpTypeImage")); } TEST_F(ValidateImage, QueryOrderSuccess) { @@ -3041,9 +2937,8 @@ TEST_F(ValidateImage, QueryOrderWrongResultType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Result Type to be int scalar type: ImageQueryOrder")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar type")); } TEST_F(ValidateImage, QueryOrderNotImage) { @@ -3056,9 +2951,8 @@ TEST_F(ValidateImage, QueryOrderNotImage) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected operand to be of type OpTypeImage: ImageQueryOrder")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected operand to be of type OpTypeImage")); } TEST_F(ValidateImage, QuerySizeLodSuccess) { @@ -3079,9 +2973,9 @@ TEST_F(ValidateImage, QuerySizeLodWrongResultType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int scalar or vector type: " - "ImageQuerySizeLod")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar or vector type")); } TEST_F(ValidateImage, QuerySizeLodResultTypeWrongSize) { @@ -3092,10 +2986,8 @@ TEST_F(ValidateImage, QuerySizeLodResultTypeWrongSize) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Result Type has 1 components, but 2 expected: ImageQuerySizeLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result Type has 1 components, but 2 expected")); } TEST_F(ValidateImage, QuerySizeLodNotImage) { @@ -3108,9 +3000,8 @@ TEST_F(ValidateImage, QuerySizeLodNotImage) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: ImageQuerySizeLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, QuerySizeLodWrongImageDim) { @@ -3121,9 +3012,8 @@ TEST_F(ValidateImage, QuerySizeLodWrongImageDim) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube: ImageQuerySizeLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube")); } TEST_F(ValidateImage, QuerySizeLodMultisampled) { @@ -3134,8 +3024,7 @@ TEST_F(ValidateImage, QuerySizeLodMultisampled) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image 'MS' must be 0: ImageQuerySizeLod")); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'MS' must be 0")); } TEST_F(ValidateImage, QuerySizeLodWrongLodType) { @@ -3147,8 +3036,7 @@ TEST_F(ValidateImage, QuerySizeLodWrongLodType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Level of Detail to be int scalar: " - "ImageQuerySizeLod")); + HasSubstr("Expected Level of Detail to be int scalar")); } TEST_F(ValidateImage, QuerySizeSuccess) { @@ -3169,9 +3057,9 @@ TEST_F(ValidateImage, QuerySizeWrongResultType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected Result Type to be int scalar or vector type: " - "ImageQuerySize")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar or vector type")); } TEST_F(ValidateImage, QuerySizeNotImage) { @@ -3184,9 +3072,8 @@ TEST_F(ValidateImage, QuerySizeNotImage) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: ImageQuerySize")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } // TODO(atgoo@github.com) Add more tests for OpQuerySize. @@ -3226,9 +3113,8 @@ TEST_F(ValidateImage, QueryLodWrongResultType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Result Type to be float vector type: ImageQueryLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be float vector type")); } TEST_F(ValidateImage, QueryLodResultTypeWrongSize) { @@ -3241,9 +3127,8 @@ TEST_F(ValidateImage, QueryLodResultTypeWrongSize) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Result Type to have 2 components: ImageQueryLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 2 components")); } TEST_F(ValidateImage, QueryLodNotSampledImage) { @@ -3256,8 +3141,7 @@ TEST_F(ValidateImage, QueryLodNotSampledImage) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Image operand to be of type OpTypeSampledImage: " - "ImageQueryLod")); + HasSubstr("Expected Image operand to be of type OpTypeSampledImage")); } TEST_F(ValidateImage, QueryLodWrongDim) { @@ -3270,9 +3154,8 @@ TEST_F(ValidateImage, QueryLodWrongDim) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube: ImageQueryLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube")); } TEST_F(ValidateImage, QueryLodWrongCoordinateType) { @@ -3285,10 +3168,8 @@ TEST_F(ValidateImage, QueryLodWrongCoordinateType) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected Coordinate to be float scalar or vector: ImageQueryLod")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); } TEST_F(ValidateImage, QueryLodCoordinateSizeTooSmall) { @@ -3303,8 +3184,7 @@ TEST_F(ValidateImage, QueryLodCoordinateSizeTooSmall) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Coordinate to have at least 2 components, " - "but given only 1: " - "ImageQueryLod")); + "but given only 1")); } TEST_F(ValidateImage, QueryLevelsSuccess) { @@ -3325,10 +3205,8 @@ TEST_F(ValidateImage, QueryLevelsWrongResultType) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Expected Result Type to be int scalar type: ImageQueryLevels")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar type")); } TEST_F(ValidateImage, QueryLevelsNotImage) { @@ -3341,9 +3219,8 @@ TEST_F(ValidateImage, QueryLevelsNotImage) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Expected Image to be of type OpTypeImage: ImageQueryLevels")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); } TEST_F(ValidateImage, QueryLevelsWrongDim) { @@ -3354,9 +3231,8 @@ TEST_F(ValidateImage, QueryLevelsWrongDim) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube: ImageQueryLevels")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube")); } TEST_F(ValidateImage, QuerySamplesSuccess) { @@ -3377,8 +3253,7 @@ TEST_F(ValidateImage, QuerySamplesNot2D) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image 'Dim' must be 2D: ImageQuerySamples")); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'Dim' must be 2D")); } TEST_F(ValidateImage, QuerySamplesNotMultisampled) { @@ -3389,8 +3264,7 @@ TEST_F(ValidateImage, QuerySamplesNotMultisampled) { CompileSuccessfully(GenerateKernelCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Image 'MS' must be 1: ImageQuerySamples")); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'MS' must be 1")); } TEST_F(ValidateImage, QueryLodWrongExecutionModel) { @@ -3450,10 +3324,8 @@ TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) { const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; CompileSuccessfully(GenerateShaderCode(body, extra, "Vertex").c_str()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - HasSubstr( - "Dim SubpassData requires Fragment execution model: ImageRead")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Dim SubpassData requires Fragment execution model")); } TEST_F(ValidateImage, SparseSampleImplicitLodSuccess) { @@ -3484,8 +3356,7 @@ TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotStruct) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseSampleImplicitLod: " - "expected Result Type to be OpTypeStruct")); + HasSubstr("Expected Result Type to be OpTypeStruct")); } TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotTwoMembers1) { @@ -3499,8 +3370,8 @@ TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotTwoMembers1) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseSampleImplicitLod: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an int " + "scalar and a texel")); } TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotTwoMembers2) { @@ -3514,8 +3385,8 @@ TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotTwoMembers2) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseSampleImplicitLod: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeFirstMemberNotInt) { @@ -3529,8 +3400,8 @@ TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeFirstMemberNotInt) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseSampleImplicitLod: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeTexelNotVector) { @@ -3545,7 +3416,7 @@ TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeTexelNotVector) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to be int or " - "float vector type: ImageSparseSampleImplicitLod")); + "float vector type")); } TEST_F(ValidateImage, SparseSampleImplicitLodWrongNumComponentsTexel) { @@ -3560,7 +3431,7 @@ TEST_F(ValidateImage, SparseSampleImplicitLodWrongNumComponentsTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to have 4 " - "components: ImageSparseSampleImplicitLod")); + "components")); } TEST_F(ValidateImage, SparseSampleImplicitLodWrongComponentTypeTexel) { @@ -3575,8 +3446,7 @@ TEST_F(ValidateImage, SparseSampleImplicitLodWrongComponentTypeTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type's second member components: " - "ImageSparseSampleImplicitLod")); + "Result Type's second member components")); } TEST_F(ValidateImage, SparseSampleDrefImplicitLodSuccess) { @@ -3607,8 +3477,7 @@ TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotStruct) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseSampleDrefImplicitLod: " - "expected Result Type to be OpTypeStruct")); + HasSubstr("Expected Result Type to be OpTypeStruct")); } TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotTwoMembers1) { @@ -3623,8 +3492,8 @@ TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotTwoMembers1) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("ImageSparseSampleDrefImplicitLod: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an int scalar " + "and a texel")); } TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotTwoMembers2) { @@ -3639,8 +3508,8 @@ TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotTwoMembers2) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("ImageSparseSampleDrefImplicitLod: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an int scalar " + "and a texel")); } TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeFirstMemberNotInt) { @@ -3655,8 +3524,8 @@ TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeFirstMemberNotInt) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("ImageSparseSampleDrefImplicitLod: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an int scalar " + "and a texel")); } TEST_F(ValidateImage, SparseSampleDrefImplicitLodDifferentSampledType) { @@ -3671,8 +3540,7 @@ TEST_F(ValidateImage, SparseSampleDrefImplicitLodDifferentSampledType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type's second member: " - "ImageSparseSampleDrefImplicitLod")); + "Result Type's second member")); } TEST_F(ValidateImage, SparseFetchSuccess) { @@ -3694,8 +3562,7 @@ TEST_F(ValidateImage, SparseFetchResultTypeNotStruct) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseFetch: " - "expected Result Type to be OpTypeStruct")); + HasSubstr("Expected Result Type to be OpTypeStruct")); } TEST_F(ValidateImage, SparseFetchResultTypeNotTwoMembers1) { @@ -3707,8 +3574,8 @@ TEST_F(ValidateImage, SparseFetchResultTypeNotTwoMembers1) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseFetch: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseFetchResultTypeNotTwoMembers2) { @@ -3720,8 +3587,8 @@ TEST_F(ValidateImage, SparseFetchResultTypeNotTwoMembers2) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseFetch: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseFetchResultTypeFirstMemberNotInt) { @@ -3733,8 +3600,8 @@ TEST_F(ValidateImage, SparseFetchResultTypeFirstMemberNotInt) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseFetch: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseFetchResultTypeTexelNotVector) { @@ -3747,7 +3614,7 @@ TEST_F(ValidateImage, SparseFetchResultTypeTexelNotVector) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to be int or " - "float vector type: ImageSparseFetch")); + "float vector type")); } TEST_F(ValidateImage, SparseFetchWrongNumComponentsTexel) { @@ -3760,7 +3627,7 @@ TEST_F(ValidateImage, SparseFetchWrongNumComponentsTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to have 4 " - "components: ImageSparseFetch")); + "components")); } TEST_F(ValidateImage, SparseFetchWrongComponentTypeTexel) { @@ -3773,8 +3640,7 @@ TEST_F(ValidateImage, SparseFetchWrongComponentTypeTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type's second member components: " - "ImageSparseFetch")); + "Result Type's second member components")); } TEST_F(ValidateImage, SparseReadSuccess) { @@ -3798,8 +3664,7 @@ TEST_F(ValidateImage, SparseReadResultTypeNotStruct) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseRead: " - "expected Result Type to be OpTypeStruct")); + HasSubstr("Expected Result Type to be OpTypeStruct")); } TEST_F(ValidateImage, SparseReadResultTypeNotTwoMembers1) { @@ -3812,8 +3677,8 @@ TEST_F(ValidateImage, SparseReadResultTypeNotTwoMembers1) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseRead: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseReadResultTypeNotTwoMembers2) { @@ -3826,8 +3691,8 @@ TEST_F(ValidateImage, SparseReadResultTypeNotTwoMembers2) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseRead: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseReadResultTypeFirstMemberNotInt) { @@ -3840,8 +3705,8 @@ TEST_F(ValidateImage, SparseReadResultTypeFirstMemberNotInt) { CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseRead: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseReadResultTypeTexelWrongType) { @@ -3855,7 +3720,7 @@ TEST_F(ValidateImage, SparseReadResultTypeTexelWrongType) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to be int or " - "float scalar or vector type: ImageSparseRead")); + "float scalar or vector type")); } TEST_F(ValidateImage, SparseReadWrongComponentTypeTexel) { @@ -3869,8 +3734,7 @@ TEST_F(ValidateImage, SparseReadWrongComponentTypeTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type's second member components: " - "ImageSparseRead")); + "Result Type's second member components")); } TEST_F(ValidateImage, SparseReadSubpassDataNotAllowed) { @@ -3910,8 +3774,7 @@ TEST_F(ValidateImage, SparseGatherResultTypeNotStruct) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseGather: " - "expected Result Type to be OpTypeStruct")); + HasSubstr("Expected Result Type to be OpTypeStruct")); } TEST_F(ValidateImage, SparseGatherResultTypeNotTwoMembers1) { @@ -3925,8 +3788,8 @@ TEST_F(ValidateImage, SparseGatherResultTypeNotTwoMembers1) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseGather: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an int " + "scalar and a texel")); } TEST_F(ValidateImage, SparseGatherResultTypeNotTwoMembers2) { @@ -3940,8 +3803,8 @@ TEST_F(ValidateImage, SparseGatherResultTypeNotTwoMembers2) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseGather: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an int " + "scalar and a texel")); } TEST_F(ValidateImage, SparseGatherResultTypeFirstMemberNotInt) { @@ -3955,8 +3818,8 @@ TEST_F(ValidateImage, SparseGatherResultTypeFirstMemberNotInt) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseGather: expected Result Type " - "to be a struct containing an int scalar and a texel")); + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); } TEST_F(ValidateImage, SparseGatherResultTypeTexelNotVector) { @@ -3971,7 +3834,7 @@ TEST_F(ValidateImage, SparseGatherResultTypeTexelNotVector) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to be int or " - "float vector type: ImageSparseGather")); + "float vector type")); } TEST_F(ValidateImage, SparseGatherWrongNumComponentsTexel) { @@ -3986,7 +3849,7 @@ TEST_F(ValidateImage, SparseGatherWrongNumComponentsTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Result Type's second member to have 4 " - "components: ImageSparseGather")); + "components")); } TEST_F(ValidateImage, SparseGatherWrongComponentTypeTexel) { @@ -4001,8 +3864,7 @@ TEST_F(ValidateImage, SparseGatherWrongComponentTypeTexel) { ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Expected Image 'Sampled Type' to be the same as " - "Result Type's second member components: " - "ImageSparseGather")); + "Result Type's second member components")); } TEST_F(ValidateImage, SparseTexelsResidentSuccess) { @@ -4022,8 +3884,9 @@ TEST_F(ValidateImage, SparseTexelsResidentResultTypeNotBool) { CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImageSparseTexelsResident: " - "expected Result Type to be bool scalar type")); + HasSubstr("Expected Result Type to be bool scalar type")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_interfaces_test.cpp b/3rdparty/spirv-tools/test/val/val_interfaces_test.cpp new file mode 100644 index 000000000..b673a040f --- /dev/null +++ b/3rdparty/spirv-tools/test/val/val_interfaces_test.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; + +using ValidateInterfacesTest = spvtest::ValidateBase; + +TEST_F(ValidateInterfacesTest, EntryPointMissingInput) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Input %3 +%5 = OpVariable %4 Input +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +%8 = OpLoad %3 %5 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <5> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, EntryPointMissingOutput) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Output %3 +%5 = OpVariable %4 Output +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +%8 = OpLoad %3 %5 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Output variable id <5> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, InterfaceMissingUseInSubfunction) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Input %3 +%5 = OpVariable %4 Input +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +%8 = OpFunctionCall %2 %9 +OpReturn +OpFunctionEnd +%9 = OpFunction %2 None %6 +%10 = OpLabel +%11 = OpLoad %3 %5 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <5> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, TwoEntryPointsOneFunction) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" %2 +OpEntryPoint Fragment %1 "func2" +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypePointer Input %4 +%2 = OpVariable %5 Input +%6 = OpTypeFunction %3 +%1 = OpFunction %3 None %6 +%7 = OpLabel +%8 = OpLoad %4 %2 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <2> is used by entry point 'func2' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, MissingInterfaceThroughInitializer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Input %3 +%5 = OpTypePointer Function %4 +%6 = OpVariable %4 Input +%7 = OpTypeFunction %2 +%1 = OpFunction %2 None %7 +%8 = OpLabel +%9 = OpVariable %5 Function %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <6> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_layout_test.cpp b/3rdparty/spirv-tools/test/val/val_layout_test.cpp index b91454ce5..145c5439b 100644 --- a/3rdparty/spirv-tools/test/val/val_layout_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_layout_test.cpp @@ -14,36 +14,30 @@ // Validation tests for Logical Layout +#include #include #include #include +#include #include +#include #include "gmock/gmock.h" #include "source/diagnostic.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" -using std::function; -using std::ostream; -using std::ostream_iterator; -using std::pair; -using std::string; -using std::stringstream; -using std::tie; -using std::tuple; -using std::vector; +namespace spvtools { +namespace val { +namespace { -using libspirv::spvResultToString; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::StrEq; -using pred_type = function; -using ValidateLayout = - spvtest::ValidateBase>>; - -namespace { +using pred_type = std::function; +using ValidateLayout = spvtest::ValidateBase< + std::tuple>>; // returns true if order is equal to VAL template @@ -71,15 +65,16 @@ spv_result_t InvalidSet(int order) { } // SPIRV source used to test the logical layout -const vector& getInstructions() { +const std::vector& getInstructions() { // clang-format off - static const vector instructions = { + static const std::vector instructions = { "OpCapability Shader", "OpExtension \"TestExtension\"", "%inst = OpExtInstImport \"GLSL.std.450\"", "OpMemoryModel Logical GLSL450", "OpEntryPoint GLCompute %func \"\"", "OpExecutionMode %func LocalSize 1 1 1", + "OpExecutionModeId %func LocalSizeId %one %one %one", "%str = OpString \"Test String\"", "%str2 = OpString \"blabla\"", "OpSource GLSL 450 %str \"uniform vec3 var = vec3(4.0);\"", @@ -133,37 +128,38 @@ INSTANTIATE_TEST_CASE_P(InstructionsOrder, // validation error. Therefore, "Lines to compile" for some instructions // are not "All" in the below. // - // | Instruction | Line(s) valid | Lines to compile - ::testing::Values( make_tuple(string("OpCapability") , Equals<0> , Range<0, 2>()) - , make_tuple(string("OpExtension") , Equals<1> , All) - , make_tuple(string("OpExtInstImport") , Equals<2> , All) - , make_tuple(string("OpMemoryModel") , Equals<3> , Range<1, kRangeEnd>()) - , make_tuple(string("OpEntryPoint") , Equals<4> , All) - , make_tuple(string("OpExecutionMode") , Equals<5> , All) - , make_tuple(string("OpSource ") , Range<6, 10>() , Range<7, kRangeEnd>()) - , make_tuple(string("OpSourceContinued ") , Range<6, 10>() , All) - , make_tuple(string("OpSourceExtension ") , Range<6, 10>() , All) - , make_tuple(string("%str2 = OpString ") , Range<6, 10>() , All) - , make_tuple(string("OpName ") , Range<11, 12>() , All) - , make_tuple(string("OpMemberName ") , Range<11, 12>() , All) - , make_tuple(string("OpDecorate ") , Range<13, 16>() , All) - , make_tuple(string("OpMemberDecorate ") , Range<13, 16>() , All) - , make_tuple(string("OpGroupDecorate ") , Range<13, 16>() , Range<16, kRangeEnd>()) - , make_tuple(string("OpDecorationGroup") , Range<13, 16>() , Range<0, 15>()) - , make_tuple(string("OpTypeBool") , Range<17, 30>() , All) - , make_tuple(string("OpTypeVoid") , Range<17, 30>() , Range<0, 25>()) - , make_tuple(string("OpTypeFloat") , Range<17, 30>() , Range<0,20>()) - , make_tuple(string("OpTypeInt") , Range<17, 30>() , Range<0, 20>()) - , make_tuple(string("OpTypeVector %floatt 4") , Range<17, 30>() , Range<19, 23>()) - , make_tuple(string("OpTypeMatrix %vec4 4") , Range<17, 30>() , Range<22, kRangeEnd>()) - , make_tuple(string("OpTypeStruct") , Range<17, 30>() , Range<24, kRangeEnd>()) - , make_tuple(string("%vfunct = OpTypeFunction"), Range<17, 30>() , Range<20, 30>()) - , make_tuple(string("OpConstant") , Range<17, 30>() , Range<20, kRangeEnd>()) - , make_tuple(string("OpLine ") , Range<17, kRangeEnd>() , Range<7, kRangeEnd>()) - , make_tuple(string("OpNoLine") , Range<17, kRangeEnd>() , All) - , make_tuple(string("%fLabel = OpLabel") , Equals<38> , All) - , make_tuple(string("OpNop") , Equals<39> , Range<39,kRangeEnd>()) - , make_tuple(string("OpReturn ; %func2 return") , Equals<40> , All) + // | Instruction | Line(s) valid | Lines to compile + ::testing::Values(std::make_tuple(std::string("OpCapability") , Equals<0> , Range<0, 2>()) + , std::make_tuple(std::string("OpExtension") , Equals<1> , All) + , std::make_tuple(std::string("OpExtInstImport") , Equals<2> , All) + , std::make_tuple(std::string("OpMemoryModel") , Equals<3> , Range<1, kRangeEnd>()) + , std::make_tuple(std::string("OpEntryPoint") , Equals<4> , All) + , std::make_tuple(std::string("OpExecutionMode ") , Range<5, 6>() , All) + , std::make_tuple(std::string("OpExecutionModeId") , Range<5, 6>() , All) + , std::make_tuple(std::string("OpSource ") , Range<7, 11>() , Range<8, kRangeEnd>()) + , std::make_tuple(std::string("OpSourceContinued ") , Range<7, 11>() , All) + , std::make_tuple(std::string("OpSourceExtension ") , Range<7, 11>() , All) + , std::make_tuple(std::string("%str2 = OpString ") , Range<7, 11>() , All) + , std::make_tuple(std::string("OpName ") , Range<12, 13>() , All) + , std::make_tuple(std::string("OpMemberName ") , Range<12, 13>() , All) + , std::make_tuple(std::string("OpDecorate ") , Range<14, 17>() , All) + , std::make_tuple(std::string("OpMemberDecorate ") , Range<14, 17>() , All) + , std::make_tuple(std::string("OpGroupDecorate ") , Range<14, 17>() , Range<17, kRangeEnd>()) + , std::make_tuple(std::string("OpDecorationGroup") , Range<14, 17>() , Range<0, 16>()) + , std::make_tuple(std::string("OpTypeBool") , Range<18, 31>() , All) + , std::make_tuple(std::string("OpTypeVoid") , Range<18, 31>() , Range<0, 26>()) + , std::make_tuple(std::string("OpTypeFloat") , Range<18, 31>() , Range<0,21>()) + , std::make_tuple(std::string("OpTypeInt") , Range<18, 31>() , Range<0, 21>()) + , std::make_tuple(std::string("OpTypeVector %floatt 4") , Range<18, 31>() , Range<20, 24>()) + , std::make_tuple(std::string("OpTypeMatrix %vec4 4") , Range<18, 31>() , Range<23, kRangeEnd>()) + , std::make_tuple(std::string("OpTypeStruct") , Range<18, 31>() , Range<25, kRangeEnd>()) + , std::make_tuple(std::string("%vfunct = OpTypeFunction"), Range<18, 31>() , Range<21, 31>()) + , std::make_tuple(std::string("OpConstant") , Range<18, 31>() , Range<21, kRangeEnd>()) + , std::make_tuple(std::string("OpLine ") , Range<18, kRangeEnd>() , Range<8, kRangeEnd>()) + , std::make_tuple(std::string("OpNoLine") , Range<18, kRangeEnd>() , All) + , std::make_tuple(std::string("%fLabel = OpLabel") , Equals<39> , All) + , std::make_tuple(std::string("OpNop") , Equals<40> , Range<40,kRangeEnd>()) + , std::make_tuple(std::string("OpReturn ; %func2 return") , Equals<41> , All) )),); // clang-format on @@ -171,15 +167,16 @@ INSTANTIATE_TEST_CASE_P(InstructionsOrder, // instructions vector and reinserts it in the location specified by order. // NOTE: This will not work correctly if there are two instances of substr in // instructions -vector GenerateCode(string substr, int order) { - vector code(getInstructions().size()); - vector inst(1); - partition_copy(begin(getInstructions()), end(getInstructions()), begin(code), - begin(inst), [=](const string& str) { - return string::npos == str.find(substr); +std::vector GenerateCode(std::string substr, int order) { + std::vector code(getInstructions().size()); + std::vector inst(1); + partition_copy(std::begin(getInstructions()), std::end(getInstructions()), + std::begin(code), std::begin(inst), + [=](const std::string& str) { + return std::string::npos == str.find(substr); }); - code.insert(begin(code) + order, inst.front()); + code.insert(std::begin(code) + order, inst.front()); return code; } @@ -188,27 +185,29 @@ vector GenerateCode(string substr, int order) { // the SPIRV source formed by combining the vector "instructions". TEST_P(ValidateLayout, Layout) { int order; - string instruction; + std::string instruction; pred_type pred; pred_type test_pred; // Predicate to determine if the test should be build - tuple testCase; + std::tuple testCase; - tie(order, testCase) = GetParam(); - tie(instruction, pred, test_pred) = testCase; + std::tie(order, testCase) = GetParam(); + std::tie(instruction, pred, test_pred) = testCase; // Skip test which break the code generation if (test_pred(order)) return; - vector code = GenerateCode(instruction, order); + std::vector code = GenerateCode(instruction, order); - stringstream ss; - copy(begin(code), end(code), ostream_iterator(ss, "\n")); + std::stringstream ss; + std::copy(std::begin(code), std::end(code), + std::ostream_iterator(ss, "\n")); + const auto env = SPV_ENV_UNIVERSAL_1_3; // printf("code: \n%s\n", ss.str().c_str()); - CompileSuccessfully(ss.str()); + CompileSuccessfully(ss.str(), env); spv_result_t result; // clang-format off - ASSERT_EQ(pred(order), result = ValidateInstructions()) + ASSERT_EQ(pred(order), result = ValidateInstructions(env)) << "Actual: " << spvResultToString(result) << "\nExpected: " << spvResultToString(pred(order)) << "\nOrder: " << order @@ -217,8 +216,8 @@ TEST_P(ValidateLayout, Layout) { // clang-format on } -TEST_F(ValidateLayout, MemoryModelMissing) { - string str = R"( +TEST_F(ValidateLayout, MemoryModelMissingBeforeEntryPoint) { + std::string str = R"( OpCapability Matrix OpExtension "TestExtension" %inst = OpExtInstImport "GLSL.std.450" @@ -234,6 +233,30 @@ TEST_F(ValidateLayout, MemoryModelMissing) { "EntryPoint cannot appear before the memory model instruction")); } +TEST_F(ValidateLayout, MemoryModelMissing) { + char str[] = R"(OpCapability Linkage)"; + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Missing required OpMemoryModel instruction")); +} + +TEST_F(ValidateLayout, MemoryModelSpecifiedTwice) { + char str[] = R"( + OpCapability Linkage + OpCapability Shader + OpMemoryModel Logical Simple + OpMemoryModel Logical Simple + )"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpMemoryModel should only be provided once")); +} + TEST_F(ValidateLayout, FunctionDefinitionBeforeDeclarationBad) { char str[] = R"( OpCapability Shader @@ -318,7 +341,7 @@ TEST_F(ValidateLayout, FuncParameterNotImmediatlyAfterFuncBad) { } TEST_F(ValidateLayout, OpUndefCanAppearInTypeDeclarationSection) { - string str = R"( + std::string str = R"( OpCapability Kernel OpCapability Linkage OpMemoryModel Logical OpenCL @@ -337,7 +360,7 @@ TEST_F(ValidateLayout, OpUndefCanAppearInTypeDeclarationSection) { } TEST_F(ValidateLayout, OpUndefCanAppearInBlock) { - string str = R"( + std::string str = R"( OpCapability Kernel OpCapability Linkage OpMemoryModel Logical OpenCL @@ -474,6 +497,21 @@ TEST_F(ValidateEntryPoint, FunctionIsTargetOfEntryPointAndFunctionCallBad) { "instruction and an OpFunctionCall instruction.")); } +// Invalid. Must be within a function to make a function call. +TEST_F(ValidateEntryPoint, FunctionCallOutsideFunctionBody) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpName %variableName "variableName" + %34 = OpFunctionCall %variableName %1 + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FunctionCall must happen within a function body.")); +} + // Valid. Module with a function but no entry point is valid when Linkage // Capability is used. TEST_F(ValidateEntryPoint, NoEntryPointWithLinkageCapGood) { @@ -610,4 +648,7 @@ TEST_F(ValidateLayout, ModuleProcessedInvalidInBasicBlock) { } // TODO(umar): Test optional instructions + } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_limits_test.cpp b/3rdparty/spirv-tools/test/val/val_limits_test.cpp index 791ffa075..55bf1e5f1 100644 --- a/3rdparty/spirv-tools/test/val/val_limits_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_limits_test.cpp @@ -19,25 +19,26 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { -using std::string; using ::testing::HasSubstr; using ::testing::MatchesRegex; using ValidateLimits = spvtest::ValidateBase; -string header = R"( +std::string header = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 )"; TEST_F(ValidateLimits, IdLargerThanBoundBad) { - string str = header + R"( + std::string str = header + R"( ; %i32 has ID 1 %i32 = OpTypeInt 32 1 %c = OpConstant %i32 100 @@ -55,7 +56,7 @@ TEST_F(ValidateLimits, IdLargerThanBoundBad) { } TEST_F(ValidateLimits, IdEqualToBoundBad) { - string str = header + R"( + std::string str = header + R"( ; %i32 has ID 1 %i32 = OpTypeInt 32 1 %c = OpConstant %i32 100 @@ -406,8 +407,11 @@ TEST_F(ValidateLimits, CustomizedNumGlobalVarsBad) { } // Valid: module has 524,287 local variables. -TEST_F(ValidateLimits, NumLocalVarsGood) { - int num_locals = 524287; +// Note: AppVeyor limits process time to 300s. For a VisualStudio Debug +// build, going up to 524287 local variables gets too close to that +// limit. So test with an artificially lowered limit. +TEST_F(ValidateLimits, NumLocalVarsGoodArtificiallyLowLimit5K) { + int num_locals = 5000; std::ostringstream spirv; spirv << header << R"( %int = OpTypeInt 32 0 @@ -428,12 +432,16 @@ TEST_F(ValidateLimits, NumLocalVarsGood) { )"; CompileSuccessfully(spirv.str()); + // Artificially limit it. + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, num_locals); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } // Invalid: module has 524,288 local variables (limit is 524,287). -TEST_F(ValidateLimits, NumLocalVarsBad) { - int num_locals = 524288; +// Artificially limit the check to 5001. +TEST_F(ValidateLimits, NumLocalVarsBadArtificiallyLowLimit5K) { + int num_locals = 5001; std::ostringstream spirv; spirv << header << R"( %int = OpTypeInt 32 0 @@ -454,10 +462,12 @@ TEST_F(ValidateLimits, NumLocalVarsBad) { )"; CompileSuccessfully(spirv.str()); + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, 5000u); EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Number of local variables ('Function' Storage Class) " - "exceeded the valid limit (524287).")); + "exceeded the valid limit (5000).")); } // Valid: module has 100 local variables (limit is 100). @@ -683,7 +693,7 @@ TEST_F(ValidateLimits, CustomizedControlFlowDepthBad) { // continue target is the loop iteself. It also exercises the case where a loop // is unreachable. TEST_F(ValidateLimits, ControlFlowNoEntryToLoopGood) { - string str = header + R"( + std::string str = header + R"( OpName %entry "entry" OpName %loop "loop" OpName %exit "exit" @@ -703,4 +713,6 @@ TEST_F(ValidateLimits, ControlFlowNoEntryToLoopGood) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_literals_test.cpp b/3rdparty/spirv-tools/test/val/val_literals_test.cpp index 8a7981b43..cbdbdd10e 100644 --- a/3rdparty/spirv-tools/test/val/val_literals_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_literals_test.cpp @@ -18,7 +18,11 @@ #include #include "gmock/gmock.h" -#include "val_fixtures.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { using ::testing::HasSubstr; @@ -26,8 +30,6 @@ using ValidateLiterals = spvtest::ValidateBase; using ValidateLiteralsShader = spvtest::ValidateBase; using ValidateLiteralsKernel = spvtest::ValidateBase; -namespace { - std::string GenerateShaderCode() { std::string str = R"( OpCapability Shader @@ -136,3 +138,5 @@ INSTANTIATE_TEST_CASE_P( "%2 = OpConstant %uint8 !0xABCDEFFF")); } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_logicals_test.cpp b/3rdparty/spirv-tools/test/val/val_logicals_test.cpp index 8464a216a..449cdd54f 100644 --- a/3rdparty/spirv-tools/test/val/val_logicals_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_logicals_test.cpp @@ -17,9 +17,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -916,4 +918,6 @@ TEST_F(ValidateLogicals, OpSGreaterThanDifferentBitWidth) { "width: SGreaterThan")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_non_uniform_test.cpp b/3rdparty/spirv-tools/test/val/val_non_uniform_test.cpp new file mode 100644 index 000000000..6ff5c127b --- /dev/null +++ b/3rdparty/spirv-tools/test/val/val_non_uniform_test.cpp @@ -0,0 +1,252 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Combine; +using ::testing::HasSubstr; +using ::testing::Values; +using ::testing::ValuesIn; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "GLCompute") { + std::ostringstream ss; + ss << R"( +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformClustered +OpCapability GroupNonUniformQuad +)"; + + ss << capabilities_and_extensions; + ss << "OpMemoryModel Logical GLSL450\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + + ss << R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%u32 = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%u32vec4 = OpTypeVector %u32 4 + +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool + +%u32_0 = OpConstant %u32 0 + +%float_0 = OpConstant %float 0 + +%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0 + +%cross_device = OpConstant %u32 0 +%device = OpConstant %u32 1 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 +%invocation = OpConstant %u32 4 + +%reduce = OpConstant %u32 0 +%inclusive_scan = OpConstant %u32 1 +%exclusive_scan = OpConstant %u32 2 +%clustered_reduce = OpConstant %u32 3 + +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup, + SpvScopeSubgroup, SpvScopeInvocation}; + +using GroupNonUniformScope = spvtest::ValidateBase< + std::tuple>; + +std::string ConvertScope(SpvScope scope) { + switch (scope) { + case SpvScopeCrossDevice: + return "%cross_device"; + case SpvScopeDevice: + return "%device"; + case SpvScopeWorkgroup: + return "%workgroup"; + case SpvScopeSubgroup: + return "%subgroup"; + case SpvScopeInvocation: + return "%invocation"; + default: + return ""; + } +} + +TEST_P(GroupNonUniformScope, Vulkan1p1) { + std::string opcode = std::get<0>(GetParam()); + std::string type = std::get<1>(GetParam()); + SpvScope execution_scope = std::get<2>(GetParam()); + std::string args = std::get<3>(GetParam()); + + std::ostringstream sstr; + sstr << "%result = " << opcode << " "; + sstr << type << " "; + sstr << ConvertScope(execution_scope) << " "; + sstr << args << "\n"; + + CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1); + spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1); + if (execution_scope == SpvScopeSubgroup) { + EXPECT_EQ(SPV_SUCCESS, result); + } else { + EXPECT_EQ(SPV_ERROR_INVALID_DATA, result); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "in Vulkan environment Execution scope is limited to Subgroup")); + } +} + +TEST_P(GroupNonUniformScope, Spirv1p3) { + std::string opcode = std::get<0>(GetParam()); + std::string type = std::get<1>(GetParam()); + SpvScope execution_scope = std::get<2>(GetParam()); + std::string args = std::get<3>(GetParam()); + + std::ostringstream sstr; + sstr << "%result = " << opcode << " "; + sstr << type << " "; + sstr << ConvertScope(execution_scope) << " "; + sstr << args << "\n"; + + CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3); + spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3); + if (execution_scope == SpvScopeSubgroup || + execution_scope == SpvScopeWorkgroup) { + EXPECT_EQ(SPV_SUCCESS, result); + } else { + EXPECT_EQ(SPV_ERROR_INVALID_DATA, result); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Execution scope is limited to Subgroup or Workgroup")); + } +} + +INSTANTIATE_TEST_CASE_P(GroupNonUniformElect, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformElect"), + Values("%bool"), ValuesIn(scopes), Values(""))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformVote, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformAll", + "OpGroupNonUniformAny", + "OpGroupNonUniformAllEqual"), + Values("%bool"), ValuesIn(scopes), + Values("%true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcast, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBroadcast"), + Values("%bool"), ValuesIn(scopes), + Values("%true %u32_0"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcastFirst, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBroadcastFirst"), + Values("%bool"), ValuesIn(scopes), + Values("%true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallot, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallot"), + Values("%u32vec4"), ValuesIn(scopes), + Values("%true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformInverseBallot, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformInverseBallot"), + Values("%bool"), ValuesIn(scopes), + Values("%u32vec4_null"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitExtract, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallotBitExtract"), + Values("%bool"), ValuesIn(scopes), + Values("%u32vec4_null %u32_0"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCount, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallotBitCount"), + Values("%u32"), ValuesIn(scopes), + Values("Reduce %u32vec4_null"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotFind, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallotFindLSB", + "OpGroupNonUniformBallotFindMSB"), + Values("%u32"), ValuesIn(scopes), + Values("%u32vec4_null"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformShuffle, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformShuffle", + "OpGroupNonUniformShuffleXor", + "OpGroupNonUniformShuffleUp", + "OpGroupNonUniformShuffleDown"), + Values("%u32"), ValuesIn(scopes), + Values("%u32_0 %u32_0"))); + +INSTANTIATE_TEST_CASE_P( + GroupNonUniformIntegerArithmetic, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul", + "OpGroupNonUniformSMin", "OpGroupNonUniformUMin", + "OpGroupNonUniformSMax", "OpGroupNonUniformUMax", + "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr", + "OpGroupNonUniformBitwiseXor"), + Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"))); + +INSTANTIATE_TEST_CASE_P( + GroupNonUniformFloatArithmetic, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul", + "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"), + Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformLogicalArithmetic, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformLogicalAnd", + "OpGroupNonUniformLogicalOr", + "OpGroupNonUniformLogicalXor"), + Values("%bool"), ValuesIn(scopes), + Values("Reduce %true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformQuad, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformQuadBroadcast", + "OpGroupNonUniformQuadSwap"), + Values("%u32"), ValuesIn(scopes), + Values("%u32_0 %u32_0"))); + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_primitives_test.cpp b/3rdparty/spirv-tools/test/val/val_primitives_test.cpp index 5ef4e63ef..f02ba8057 100644 --- a/3rdparty/spirv-tools/test/val/val_primitives_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_primitives_test.cpp @@ -16,9 +16,11 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; @@ -310,4 +312,6 @@ OpEndStreamPrimitive %val1 "expected Stream to be constant instruction")); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_ssa_test.cpp b/3rdparty/spirv-tools/test/val/val_ssa_test.cpp index f6a712fa2..25944f518 100644 --- a/3rdparty/spirv-tools/test/val/val_ssa_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_ssa_test.cpp @@ -19,18 +19,17 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { using ::testing::HasSubstr; using ::testing::MatchesRegex; -using std::pair; -using std::string; -using std::stringstream; - -namespace { -using ValidateSSA = spvtest::ValidateBase>; +using ValidateSSA = spvtest::ValidateBase>; TEST_F(ValidateSSA, Default) { char str[] = R"( @@ -119,7 +118,8 @@ TEST_F(ValidateSSA, DominateUsageWithinBlockBad) { CompileSuccessfully(str); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - MatchesRegex("ID .\\[bad\\] has not been defined")); + MatchesRegex("ID .\\[bad\\] has not been defined\n" + " %8 = OpIAdd %uint %uint_1 %bad\n")); } TEST_F(ValidateSSA, DominateUsageSameInstructionBad) { @@ -141,7 +141,8 @@ TEST_F(ValidateSSA, DominateUsageSameInstructionBad) { CompileSuccessfully(str); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - MatchesRegex("ID .\\[sum\\] has not been defined")); + MatchesRegex("ID .\\[sum\\] has not been defined\n" + " %sum = OpIAdd %uint %uint_1 %sum\n")); } TEST_F(ValidateSSA, ForwardNameGood) { @@ -545,14 +546,14 @@ TEST_F(ValidateSSA, ForwardBranchConditionalMissingTargetBad) { // Since Int8 requires the Kernel capability, the signedness of int types may // not be "1". -const string kHeader = R"( +const std::string kHeader = R"( OpCapability Int8 OpCapability DeviceEnqueue OpCapability Linkage OpMemoryModel Logical OpenCL )"; -const string kBasicTypes = R"( +const std::string kBasicTypes = R"( %voidt = OpTypeVoid %boolt = OpTypeBool %int8t = OpTypeInt 8 0 @@ -565,7 +566,7 @@ const string kBasicTypes = R"( %false = OpConstantFalse %boolt )"; -const string kKernelTypesAndConstants = R"( +const std::string kKernelTypesAndConstants = R"( %queuet = OpTypeQueue %three = OpConstant %uintt 3 @@ -590,14 +591,14 @@ const string kKernelTypesAndConstants = R"( %kfunct = OpTypeFunction %voidt %intptrt )"; -const string kKernelSetup = R"( +const std::string kKernelSetup = R"( %dqueue = OpGetDefaultQueue %queuet %ndval = OpBuildNDRange %ndt %gl %local %offset %revent = OpUndef %eventt )"; -const string kKernelDefinition = R"( +const std::string kKernelDefinition = R"( %kfunc = OpFunction %voidt None %kfunct %iparam = OpFunctionParameter %intptrt %kfuncl = OpLabel @@ -607,8 +608,8 @@ const string kKernelDefinition = R"( )"; TEST_F(ValidateSSA, EnqueueKernelGood) { - string str = kHeader + kBasicTypes + kKernelTypesAndConstants + - kKernelDefinition + R"( + std::string str = kHeader + kBasicTypes + kKernelTypesAndConstants + + kKernelDefinition + R"( %main = OpFunction %voidt None %vfunct %mainl = OpLabel )" + kKernelSetup + R"( @@ -623,11 +624,11 @@ TEST_F(ValidateSSA, EnqueueKernelGood) { } TEST_F(ValidateSSA, ForwardEnqueueKernelGood) { - string str = kHeader + kBasicTypes + kKernelTypesAndConstants + R"( + std::string str = kHeader + kBasicTypes + kKernelTypesAndConstants + R"( %main = OpFunction %voidt None %vfunct %mainl = OpLabel )" + - kKernelSetup + R"( + kKernelSetup + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event %revent %kfunc %firstp %psize %palign %lsize @@ -639,8 +640,8 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelGood) { } TEST_F(ValidateSSA, EnqueueMissingFunctionBad) { - string str = kHeader + "OpName %kfunc \"kfunc\"" + kBasicTypes + - kKernelTypesAndConstants + R"( + std::string str = kHeader + "OpName %kfunc \"kfunc\"" + kBasicTypes + + kKernelTypesAndConstants + R"( %main = OpFunction %voidt None %vfunct %mainl = OpLabel )" + kKernelSetup + R"( @@ -655,25 +656,25 @@ TEST_F(ValidateSSA, EnqueueMissingFunctionBad) { EXPECT_THAT(getDiagnosticString(), HasSubstr("kfunc")); } -string forwardKernelNonDominantParameterBaseCode(string name = string()) { - string op_name; +std::string forwardKernelNonDominantParameterBaseCode( + std::string name = std::string()) { + std::string op_name; if (name.empty()) { op_name = ""; } else { op_name = "\nOpName %" + name + " \"" + name + "\"\n"; } - string out = kHeader + op_name + kBasicTypes + kKernelTypesAndConstants + - kKernelDefinition + - R"( + std::string out = kHeader + op_name + kBasicTypes + kKernelTypesAndConstants + + kKernelDefinition + + R"( %main = OpFunction %voidt None %vfunct %mainl = OpLabel - )" + - kKernelSetup; + )" + kKernelSetup; return out; } TEST_F(ValidateSSA, ForwardEnqueueKernelMissingParameter1Bad) { - string str = forwardKernelNonDominantParameterBaseCode("missing") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("missing") + R"( %err = OpEnqueueKernel %missing %dqueue %flags %ndval %nevent %event %revent %kfunc %firstp %psize %palign %lsize @@ -686,7 +687,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelMissingParameter1Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter2Bad) { - string str = forwardKernelNonDominantParameterBaseCode("dqueue2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("dqueue2") + R"( %err = OpEnqueueKernel %uintt %dqueue2 %flags %ndval %nevent %event %revent %kfunc %firstp %psize %palign %lsize @@ -700,7 +701,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter2Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter3Bad) { - string str = forwardKernelNonDominantParameterBaseCode("ndval2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("ndval2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval2 %nevent %event %revent %kfunc %firstp %psize %palign %lsize @@ -714,7 +715,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter3Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter4Bad) { - string str = forwardKernelNonDominantParameterBaseCode("nevent2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("nevent2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent2 %event %revent %kfunc %firstp %psize %palign %lsize @@ -728,7 +729,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter4Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter5Bad) { - string str = forwardKernelNonDominantParameterBaseCode("event2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("event2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event2 %revent %kfunc %firstp %psize %palign %lsize @@ -742,7 +743,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter5Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter6Bad) { - string str = forwardKernelNonDominantParameterBaseCode("revent2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("revent2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event %revent2 %kfunc %firstp %psize %palign %lsize @@ -756,7 +757,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter6Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter8Bad) { - string str = forwardKernelNonDominantParameterBaseCode("firstp2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("firstp2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event %revent %kfunc %firstp2 %psize %palign %lsize @@ -770,7 +771,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter8Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter9Bad) { - string str = forwardKernelNonDominantParameterBaseCode("psize2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("psize2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event %revent %kfunc %firstp %psize2 %palign %lsize @@ -784,7 +785,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter9Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter10Bad) { - string str = forwardKernelNonDominantParameterBaseCode("palign2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("palign2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event %revent %kfunc %firstp %psize %palign2 %lsize @@ -798,7 +799,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter10Bad) { } TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter11Bad) { - string str = forwardKernelNonDominantParameterBaseCode("lsize2") + R"( + std::string str = forwardKernelNonDominantParameterBaseCode("lsize2") + R"( %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent %event %revent %kfunc %firstp %psize %palign %lsize2 @@ -814,7 +815,7 @@ TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter11Bad) { static const bool kWithNDrange = true; static const bool kNoNDrange = false; -pair cases[] = { +std::pair cases[] = { {"OpGetKernelNDrangeSubGroupCount", kWithNDrange}, {"OpGetKernelNDrangeMaxSubGroupSize", kWithNDrange}, {"OpGetKernelWorkGroupSize", kNoNDrange}, @@ -822,17 +823,17 @@ pair cases[] = { INSTANTIATE_TEST_CASE_P(KernelArgs, ValidateSSA, ::testing::ValuesIn(cases), ); -static const string return_instructions = R"( +static const std::string return_instructions = R"( OpReturn OpFunctionEnd )"; TEST_P(ValidateSSA, GetKernelGood) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode() + " %numsg = " << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign" @@ -844,12 +845,12 @@ TEST_P(ValidateSSA, GetKernelGood) { } TEST_P(ValidateSSA, ForwardGetKernelGood) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; // clang-format off - string str = kHeader + kBasicTypes + kKernelTypesAndConstants + + std::string str = kHeader + kBasicTypes + kKernelTypesAndConstants + R"( %main = OpFunction %voidt None %vfunct %mainl = OpLabel @@ -864,11 +865,11 @@ TEST_P(ValidateSSA, ForwardGetKernelGood) { } TEST_P(ValidateSSA, ForwardGetKernelMissingDefinitionBad) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode("missing") + " %numsg = " << instruction + " %uintt" + ndrange_param + "%missing %firstp %psize %palign" @@ -881,11 +882,11 @@ TEST_P(ValidateSSA, ForwardGetKernelMissingDefinitionBad) { } TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountMissingParameter1Bad) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode("missing") + " %numsg = " << instruction + " %missing" + ndrange_param + "%kfunc %firstp %psize %palign" @@ -899,11 +900,11 @@ TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountMissingParameter1Bad) { TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountNonDominantParameter2Bad) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval2 " : " "; + std::string ndrange_param = with_ndrange ? " %ndval2 " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode("ndval2") + " %numsg = " << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign" @@ -920,11 +921,11 @@ TEST_P(ValidateSSA, TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountNonDominantParameter4Bad) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode("firstp2") + " %numsg = " << instruction + " %uintt" + ndrange_param + "%kfunc %firstp2 %psize %palign" @@ -939,11 +940,11 @@ TEST_P(ValidateSSA, TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountNonDominantParameter5Bad) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode("psize2") + " %numsg = " << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize2 %palign" @@ -958,11 +959,11 @@ TEST_P(ValidateSSA, TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountNonDominantParameter6Bad) { - string instruction = GetParam().first; + std::string instruction = GetParam().first; bool with_ndrange = GetParam().second; - string ndrange_param = with_ndrange ? " %ndval " : " "; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; - stringstream ss; + std::stringstream ss; // clang-format off ss << forwardKernelNonDominantParameterBaseCode("palign2") + " %numsg = " << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign2" @@ -978,8 +979,8 @@ TEST_P(ValidateSSA, } TEST_F(ValidateSSA, PhiGood) { - string str = kHeader + kBasicTypes + - R"( + std::string str = kHeader + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %preheader = OpLabel %init = OpCopyObject %uintt %zero @@ -1001,8 +1002,8 @@ TEST_F(ValidateSSA, PhiGood) { } TEST_F(ValidateSSA, PhiMissingTypeBad) { - string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %preheader = OpLabel %init = OpCopyObject %uintt %zero @@ -1025,8 +1026,8 @@ TEST_F(ValidateSSA, PhiMissingTypeBad) { } TEST_F(ValidateSSA, PhiMissingIdBad) { - string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %preheader = OpLabel %init = OpCopyObject %uintt %zero @@ -1049,8 +1050,8 @@ TEST_F(ValidateSSA, PhiMissingIdBad) { } TEST_F(ValidateSSA, PhiMissingLabelBad) { - string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %preheader = OpLabel %init = OpCopyObject %uintt %zero @@ -1073,8 +1074,8 @@ TEST_F(ValidateSSA, PhiMissingLabelBad) { } TEST_F(ValidateSSA, IdDominatesItsUseGood) { - string str = kHeader + kBasicTypes + - R"( + std::string str = kHeader + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel %cond = OpSLessThan %boolt %one %ten @@ -1097,12 +1098,12 @@ TEST_F(ValidateSSA, IdDominatesItsUseGood) { } TEST_F(ValidateSSA, IdDoesNotDominateItsUseBad) { - string str = kHeader + - "OpName %eleven \"eleven\"\n" - "OpName %true_block \"true_block\"\n" - "OpName %false_block \"false_block\"" + - kBasicTypes + - R"( + std::string str = kHeader + + "OpName %eleven \"eleven\"\n" + "OpName %true_block \"true_block\"\n" + "OpName %false_block \"false_block\"" + + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel %cond = OpSLessThan %boolt %one %ten @@ -1124,12 +1125,13 @@ TEST_F(ValidateSSA, IdDoesNotDominateItsUseBad) { EXPECT_THAT( getDiagnosticString(), MatchesRegex("ID .\\[eleven\\] defined in block .\\[true_block\\] does " - "not dominate its use in block .\\[false_block\\]")); + "not dominate its use in block .\\[false_block\\]\n" + " %false_block = OpLabel\n")); } TEST_F(ValidateSSA, PhiUseDoesntDominateDefinitionGood) { - string str = kHeader + kBasicTypes + - R"( + std::string str = kHeader + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel %var_one = OpVariable %intptrt Function %one @@ -1156,8 +1158,8 @@ TEST_F(ValidateSSA, PhiUseDoesntDominateDefinitionGood) { TEST_F(ValidateSSA, PhiUseDoesntDominateUseOfPhiOperandUsedBeforeDefinitionBad) { - string str = kHeader + "OpName %inew \"inew\"" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %inew \"inew\"" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel %var_one = OpVariable %intptrt Function %one @@ -1182,14 +1184,15 @@ TEST_F(ValidateSSA, CompileSuccessfully(str); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - MatchesRegex("ID .\\[inew\\] has not been defined")); + MatchesRegex("ID .\\[inew\\] has not been defined\n" + " %19 = OpIAdd %uint %inew %uint_1\n")); } TEST_F(ValidateSSA, PhiUseMayComeFromNonDominatingBlockGood) { - string str = kHeader + "OpName %if_true \"if_true\"\n" + - "OpName %exit \"exit\"\n" + "OpName %copy \"copy\"\n" + - kBasicTypes + - R"( + std::string str = kHeader + "OpName %if_true \"if_true\"\n" + + "OpName %exit \"exit\"\n" + "OpName %copy \"copy\"\n" + + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel OpBranchConditional %false %if_true %exit @@ -1216,9 +1219,9 @@ TEST_F(ValidateSSA, PhiUsesItsOwnDefinitionGood) { // // Non-phi instructions can't use their own definitions, as // already checked in test DominateUsageSameInstructionBad. - string str = kHeader + "OpName %loop \"loop\"\n" + - "OpName %value \"value\"\n" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %loop \"loop\"\n" + + "OpName %value \"value\"\n" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel OpBranch %loop @@ -1235,11 +1238,12 @@ TEST_F(ValidateSSA, PhiUsesItsOwnDefinitionGood) { } TEST_F(ValidateSSA, PhiVariableDefNotDominatedByParentBlockBad) { - string str = kHeader + "OpName %if_true \"if_true\"\n" + - "OpName %if_false \"if_false\"\n" + "OpName %exit \"exit\"\n" + - "OpName %value \"phi\"\n" + "OpName %true_copy \"true_copy\"\n" + - "OpName %false_copy \"false_copy\"\n" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %if_true \"if_true\"\n" + + "OpName %if_false \"if_false\"\n" + + "OpName %exit \"exit\"\n" + "OpName %value \"phi\"\n" + + "OpName %true_copy \"true_copy\"\n" + + "OpName %false_copy \"false_copy\"\n" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel OpBranchConditional %false %if_true %if_false @@ -1264,12 +1268,14 @@ TEST_F(ValidateSSA, PhiVariableDefNotDominatedByParentBlockBad) { EXPECT_THAT( getDiagnosticString(), MatchesRegex("In OpPhi instruction .\\[phi\\], ID .\\[true_copy\\] " - "definition does not dominate its parent .\\[if_false\\]")); + "definition does not dominate its parent .\\[if_false\\]\n" + " %phi = OpPhi %bool %true_copy %if_false %false_copy " + "%if_true\n")); } TEST_F(ValidateSSA, PhiVariableDefDominatesButNotDefinedInParentBlock) { - string str = kHeader + "OpName %if_true \"if_true\"\n" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %if_true \"if_true\"\n" + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel OpBranchConditional %false %if_true %if_false @@ -1298,8 +1304,8 @@ TEST_F(ValidateSSA, PhiVariableDefDominatesButNotDefinedInParentBlock) { TEST_F(ValidateSSA, DominanceCheckIgnoresUsesInUnreachableBlocksDefInBlockGood) { - string str = kHeader + kBasicTypes + - R"( + std::string str = kHeader + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel %def = OpCopyObject %boolt %false @@ -1316,8 +1322,9 @@ TEST_F(ValidateSSA, } TEST_F(ValidateSSA, PhiVariableUnreachableDefNotInParentBlock) { - string str = kHeader + "OpName %unreachable \"unreachable\"\n" + kBasicTypes + - R"( + std::string str = kHeader + "OpName %unreachable \"unreachable\"\n" + + kBasicTypes + + R"( %func = OpFunction %voidt None %vfunct %entry = OpLabel OpBranch %if_false @@ -1346,8 +1353,8 @@ TEST_F(ValidateSSA, PhiVariableUnreachableDefNotInParentBlock) { TEST_F(ValidateSSA, DominanceCheckIgnoresUsesInUnreachableBlocksDefIsParamGood) { - string str = kHeader + kBasicTypes + - R"( + std::string str = kHeader + kBasicTypes + + R"( %void_fn_int = OpTypeFunction %voidt %uintt %func = OpFunction %voidt None %void_fn_int %int_param = OpFunctionParameter %uintt @@ -1365,11 +1372,11 @@ TEST_F(ValidateSSA, } TEST_F(ValidateSSA, UseFunctionParameterFromOtherFunctionBad) { - string str = kHeader + - "OpName %first \"first\"\n" - "OpName %func \"func\"\n" + - "OpName %func2 \"func2\"\n" + kBasicTypes + - R"( + std::string str = kHeader + + "OpName %first \"first\"\n" + "OpName %func \"func\"\n" + + "OpName %func2 \"func2\"\n" + kBasicTypes + + R"( %viifunct = OpTypeFunction %voidt %uintt %uintt %func = OpFunction %voidt None %viifunct %first = OpFunctionParameter %uintt @@ -1389,14 +1396,15 @@ TEST_F(ValidateSSA, UseFunctionParameterFromOtherFunctionBad) { EXPECT_THAT( getDiagnosticString(), MatchesRegex("ID .\\[first\\] used in function .\\[func2\\] is used " - "outside of it's defining function .\\[func\\]")); + "outside of it's defining function .\\[func\\]\n" + " %func = OpFunction %void None %14\n")); } TEST_F(ValidateSSA, TypeForwardPointerForwardReference) { // See https://github.com/KhronosGroup/SPIRV-Tools/issues/429 // // ForwardPointers can references instructions that have not been defined - string str = R"( + std::string str = R"( OpCapability Kernel OpCapability Addresses OpCapability Linkage @@ -1412,7 +1420,7 @@ TEST_F(ValidateSSA, TypeForwardPointerForwardReference) { } TEST_F(ValidateSSA, TypeStructForwardReference) { - string str = R"( + std::string str = R"( OpCapability Kernel OpCapability Addresses OpCapability Linkage @@ -1432,4 +1440,7 @@ TEST_F(ValidateSSA, TypeStructForwardReference) { } // TODO(umar): OpGroupMemberDecorate + } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_state_test.cpp b/3rdparty/spirv-tools/test/val/val_state_test.cpp index c63a0c5c6..699b224ad 100644 --- a/3rdparty/spirv-tools/test/val/val_state_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_state_test.cpp @@ -18,22 +18,22 @@ #include #include "gtest/gtest.h" -#include "latest_version_spirv_header.h" +#include "source/latest_version_spirv_header.h" -#include "enum_set.h" -#include "extensions.h" -#include "spirv_validator_options.h" -#include "val/construct.h" -#include "val/function.h" -#include "val/validation_state.h" -#include "validate.h" +#include "source/enum_set.h" +#include "source/extensions.h" +#include "source/spirv_validator_options.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "source/val/validate.h" +#include "source/val/validation_state.h" +namespace spvtools { +namespace val { namespace { -using libspirv::CapabilitySet; -using libspirv::Extension; -using libspirv::ExtensionSet; -using libspirv::ValidationState_t; -using std::vector; + +// This is all we need for these tests. +static uint32_t kFakeBinary[] = {0}; // A test with a ValidationState_t member transparently. class ValidationStateTest : public testing::Test { @@ -41,7 +41,7 @@ class ValidationStateTest : public testing::Test { ValidationStateTest() : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), options_(spvValidatorOptionsCreate()), - state_(context_, options_) {} + state_(context_, options_, kFakeBinary, 0) {} ~ValidationStateTest() { spvContextDestroy(context_); @@ -133,4 +133,7 @@ TEST_F(ValidationState_HasAnyOfExtensions, MultiCapMask) { EXPECT_TRUE(state_.HasAnyOfExtensions(set1)); EXPECT_FALSE(state_.HasAnyOfExtensions(set2)); } + } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_storage_test.cpp b/3rdparty/spirv-tools/test/val/val_storage_test.cpp index 02ef6f58d..46b3ddcbb 100644 --- a/3rdparty/spirv-tools/test/val/val_storage_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_storage_test.cpp @@ -19,14 +19,15 @@ #include #include "gmock/gmock.h" -#include "val_fixtures.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { using ::testing::HasSubstr; - using ValidateStorage = spvtest::ValidateBase; -namespace { - TEST_F(ValidateStorage, FunctionStorageInsideFunction) { char str[] = R"( OpCapability Shader @@ -175,4 +176,7 @@ TEST_F(ValidateStorage, GenericVariableInsideFunction) { EXPECT_THAT(getDiagnosticString(), HasSubstr("OpVariable storage class cannot be Generic")); } + } // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_type_unique_test.cpp b/3rdparty/spirv-tools/test/val/val_type_unique_test.cpp index ad16f246d..67ceaddb8 100644 --- a/3rdparty/spirv-tools/test/val/val_type_unique_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_type_unique_test.cpp @@ -17,22 +17,22 @@ #include #include "gmock/gmock.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { using ::testing::HasSubstr; using ::testing::Not; -using std::string; - using ValidateTypeUnique = spvtest::ValidateBase; const spv_result_t kDuplicateTypeError = SPV_ERROR_INVALID_DATA; -const string& GetHeader() { - static const string header = R"( +const std::string& GetHeader() { + static const std::string header = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -64,8 +64,8 @@ OpMemoryModel Logical GLSL450 return header; } -const string& GetBody() { - static const string body = R"( +const std::string& GetBody() { + static const std::string body = R"( %main = OpFunction %voidt None %vfunct %mainl = OpLabel %a = OpIAdd %uintt %const3 %val3 @@ -90,19 +90,19 @@ OpFunctionEnd // Returns expected error string if |opcode| produces a duplicate type // declaration. -string GetErrorString(SpvOp opcode) { +std::string GetErrorString(SpvOp opcode) { return "Duplicate non-aggregate type declarations are not allowed. Opcode: " + std::string(spvOpcodeString(opcode)); } TEST_F(ValidateTypeUnique, success) { - string str = GetHeader() + GetBody(); + std::string str = GetHeader() + GetBody(); CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateTypeUnique, duplicate_void) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %boolt2 = OpTypeVoid )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -111,7 +111,7 @@ TEST_F(ValidateTypeUnique, duplicate_void) { } TEST_F(ValidateTypeUnique, duplicate_bool) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %boolt2 = OpTypeBool )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -120,7 +120,7 @@ TEST_F(ValidateTypeUnique, duplicate_bool) { } TEST_F(ValidateTypeUnique, duplicate_int) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %uintt2 = OpTypeInt 32 0 )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -129,7 +129,7 @@ TEST_F(ValidateTypeUnique, duplicate_int) { } TEST_F(ValidateTypeUnique, duplicate_float) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %floatt2 = OpTypeFloat 32 )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -138,7 +138,7 @@ TEST_F(ValidateTypeUnique, duplicate_float) { } TEST_F(ValidateTypeUnique, duplicate_vec3) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %vec3t2 = OpTypeVector %floatt 3 )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -148,7 +148,7 @@ TEST_F(ValidateTypeUnique, duplicate_vec3) { } TEST_F(ValidateTypeUnique, duplicate_mat33) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %mat33t2 = OpTypeMatrix %vec3t 3 )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -158,7 +158,7 @@ TEST_F(ValidateTypeUnique, duplicate_mat33) { } TEST_F(ValidateTypeUnique, duplicate_vfunc) { - string str = GetHeader() + R"( + std::string str = GetHeader() + R"( %vfunct2 = OpTypeFunction %voidt )" + GetBody(); CompileSuccessfully(str.c_str()); @@ -168,7 +168,7 @@ TEST_F(ValidateTypeUnique, duplicate_vfunc) { } TEST_F(ValidateTypeUnique, duplicate_pipe_storage) { - string str = R"( + std::string str = R"( OpCapability Addresses OpCapability Kernel OpCapability Linkage @@ -185,7 +185,7 @@ OpMemoryModel Physical32 OpenCL } TEST_F(ValidateTypeUnique, duplicate_named_barrier) { - string str = R"( + std::string str = R"( OpCapability Addresses OpCapability Kernel OpCapability Linkage @@ -201,7 +201,7 @@ OpMemoryModel Physical32 OpenCL } TEST_F(ValidateTypeUnique, duplicate_forward_pointer) { - string str = R"( + std::string str = R"( OpCapability Addresses OpCapability Kernel OpCapability GenericPointer @@ -219,7 +219,7 @@ OpTypeForwardPointer %ptr2 Generic } TEST_F(ValidateTypeUnique, duplicate_void_with_extension) { - string str = R"( + std::string str = R"( OpCapability Addresses OpCapability Kernel OpCapability Linkage @@ -236,7 +236,7 @@ OpMemoryModel Physical32 OpenCL } TEST_F(ValidateTypeUnique, DuplicatePointerTypesNoExtension) { - string str = R"( + std::string str = R"( OpCapability Shader OpCapability Linkage OpMemoryModel Logical GLSL450 @@ -245,13 +245,11 @@ OpMemoryModel Logical GLSL450 %ptr2 = OpTypePointer Input %u32 )"; CompileSuccessfully(str.c_str()); - ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr(GetErrorString(SpvOpTypePointer))); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } TEST_F(ValidateTypeUnique, DuplicatePointerTypesWithExtension) { - string str = R"( + std::string str = R"( OpCapability Shader OpCapability Linkage OpExtension "SPV_KHR_variable_pointers" @@ -266,4 +264,6 @@ OpMemoryModel Logical GLSL450 Not(HasSubstr(GetErrorString(SpvOpTypePointer)))); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_validation_state_test.cpp b/3rdparty/spirv-tools/test/val/val_validation_state_test.cpp index 7af6a4609..68504c528 100644 --- a/3rdparty/spirv-tools/test/val/val_validation_state_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_validation_state_test.cpp @@ -17,13 +17,14 @@ #include #include "gmock/gmock.h" -#include "spirv_validator_options.h" -#include "unit_spirv.h" -#include "val_fixtures.h" +#include "source/spirv_validator_options.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" +namespace spvtools { +namespace val { namespace { -using std::string; using ::testing::HasSubstr; using ValidationStateTest = spvtest::ValidateBase; @@ -43,7 +44,7 @@ const char kVoidFVoid[] = // Tests that the instruction count in ValidationState is correct. TEST_F(ValidationStateTest, CheckNumInstructions) { - string spirv = string(header) + "%int = OpTypeInt 32 0"; + std::string spirv = std::string(header) + "%int = OpTypeInt 32 0"; CompileSuccessfully(spirv); EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size()); @@ -51,7 +52,7 @@ TEST_F(ValidationStateTest, CheckNumInstructions) { // Tests that the number of global variables in ValidationState is correct. TEST_F(ValidationStateTest, CheckNumGlobalVars) { - string spirv = string(header) + R"( + std::string spirv = std::string(header) + R"( %int = OpTypeInt 32 0 %_ptr_int = OpTypePointer Input %int %var_1 = OpVariable %_ptr_int Input @@ -64,7 +65,7 @@ TEST_F(ValidationStateTest, CheckNumGlobalVars) { // Tests that the number of local variables in ValidationState is correct. TEST_F(ValidationStateTest, CheckNumLocalVars) { - string spirv = string(header) + R"( + std::string spirv = std::string(header) + R"( %int = OpTypeInt 32 0 %_ptr_int = OpTypePointer Function %int %voidt = OpTypeVoid @@ -84,7 +85,7 @@ TEST_F(ValidationStateTest, CheckNumLocalVars) { // Tests that the "id bound" in ValidationState is correct. TEST_F(ValidationStateTest, CheckIdBound) { - string spirv = string(header) + R"( + std::string spirv = std::string(header) + R"( %int = OpTypeInt 32 0 %voidt = OpTypeVoid )"; @@ -95,8 +96,9 @@ TEST_F(ValidationStateTest, CheckIdBound) { // Tests that the entry_points in ValidationState is correct. TEST_F(ValidationStateTest, CheckEntryPoints) { - string spirv = string(header) + " OpEntryPoint Vertex %func \"shader\"" + - string(kVoidFVoid); + std::string spirv = std::string(header) + + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid); CompileSuccessfully(spirv); EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); EXPECT_EQ(size_t(1), vstate_->entry_points().size()); @@ -152,4 +154,6 @@ TEST_F(ValidationStateTest, CheckAccessChainIndexesLimitOption) { EXPECT_EQ(100u, options_->universal_limits_.max_access_chain_indexes); } -} // anonymous namespace +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_version_test.cpp b/3rdparty/spirv-tools/test/val/val_version_test.cpp index 8c0ec4e1a..fa252ac86 100644 --- a/3rdparty/spirv-tools/test/val/val_version_test.cpp +++ b/3rdparty/spirv-tools/test/val/val_version_test.cpp @@ -12,15 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "gmock/gmock.h" -#include "val_fixtures.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { -using namespace spvtest; using ::testing::HasSubstr; -using std::make_tuple; -using ValidateVersion = - ValidateBase>; +using ValidateVersion = spvtest::ValidateBase< + std::tuple>; const std::string vulkan_spirv = R"( OpCapability Shader @@ -63,6 +68,7 @@ std::string version(spv_target_env env) { return "1.2"; case SPV_ENV_UNIVERSAL_1_3: case SPV_ENV_VULKAN_1_1: + case SPV_ENV_WEBGPU_0: return "1.3"; default: return "0"; @@ -89,176 +95,184 @@ TEST_P(ValidateVersion, version) { INSTANTIATE_TEST_CASE_P(Universal, ValidateVersion, ::testing::Values( // Binary version, Target environment - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_VULKAN_1_0, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_0, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_2, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_3, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_5, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_VULKAN_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_5, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_WEBGPU_0, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_WEBGPU_0, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_WEBGPU_0, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), - make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_5, vulkan_spirv, false) + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_WEBGPU_0, vulkan_spirv, true) ) ); INSTANTIATE_TEST_CASE_P(Vulkan, ValidateVersion, ::testing::Values( // Binary version, Target environment - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_0, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_0, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_1, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_2, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_3, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_5, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_5, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), - make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_5, vulkan_spirv, false) + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_5, vulkan_spirv, false) ) ); INSTANTIATE_TEST_CASE_P(OpenCL, ValidateVersion, ::testing::Values( // Binary version, Target environment - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_0, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_1, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_1_2, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_1_2, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_1_2, opencl_spirv, true) + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_1_2, opencl_spirv, true) ) ); INSTANTIATE_TEST_CASE_P(OpenCLEmbedded, ValidateVersion, ::testing::Values( // Binary version, Target environment - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_0, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_1, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, false), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), - make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_1_2, opencl_spirv, false) + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_1_2, opencl_spirv, false) ) ); // clang-format on + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/test/val/val_webgpu_test.cpp b/3rdparty/spirv-tools/test/val/val_webgpu_test.cpp new file mode 100644 index 000000000..b65d08fe6 --- /dev/null +++ b/3rdparty/spirv-tools/test/val/val_webgpu_test.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for WebGPU env specific checks + +#include + +#include "gmock/gmock.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using testing::HasSubstr; + +using ValidateWebGPU = spvtest::ValidateBase; + +TEST_F(ValidateWebGPU, OpUndefIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %float = OpTypeFloat 32 + %1 = OpUndef %float +)"; + + CompileSuccessfully(spirv); + + // Control case: OpUndef is allowed in SPIR-V 1.3 + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + + // Control case: OpUndef is disallowed in the WebGPU env + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpUndef is disallowed")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/3rdparty/spirv-tools/tools/CMakeLists.txt b/3rdparty/spirv-tools/tools/CMakeLists.txt index 5dacca38a..67143d879 100644 --- a/3rdparty/spirv-tools/tools/CMakeLists.txt +++ b/3rdparty/spirv-tools/tools/CMakeLists.txt @@ -40,12 +40,15 @@ endfunction() if (NOT ${SPIRV_SKIP_EXECUTABLES}) add_spvtools_tool(TARGET spirv-as SRCS as/as.cpp LIBS ${SPIRV_TOOLS}) add_spvtools_tool(TARGET spirv-dis SRCS dis/dis.cpp LIBS ${SPIRV_TOOLS}) - add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp LIBS ${SPIRV_TOOLS}) - add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp util/cli_consumer.cpp LIBS ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS}) add_spvtools_tool(TARGET spirv-link SRCS link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS}) add_spvtools_tool(TARGET spirv-stats SRCS stats/stats.cpp - stats/stats_analyzer.cpp + stats/stats_analyzer.cpp + stats/stats_analyzer.h + stats/spirv_stats.cpp + stats/spirv_stats.h LIBS ${SPIRV_TOOLS}) add_spvtools_tool(TARGET spirv-cfg SRCS cfg/cfg.cpp diff --git a/3rdparty/spirv-tools/tools/as/as.cpp b/3rdparty/spirv-tools/tools/as/as.cpp index 0d9363f77..287ba51f8 100644 --- a/3rdparty/spirv-tools/tools/as/as.cpp +++ b/3rdparty/spirv-tools/tools/as/as.cpp @@ -27,7 +27,7 @@ void print_usage(char* argv0) { Usage: %s [options] [] The SPIR-V assembly text is read from . If no file is specified, -or if the filename is "-", then the binary is read from standard input. +or if the filename is "-", then the assembly text is read from standard input. The SPIR-V binary module is written to file "out.spv", unless the -o option is used. diff --git a/3rdparty/spirv-tools/tools/cfg/bin_to_dot.cpp b/3rdparty/spirv-tools/tools/cfg/bin_to_dot.cpp index a7a2a2034..2561eea40 100644 --- a/3rdparty/spirv-tools/tools/cfg/bin_to_dot.cpp +++ b/3rdparty/spirv-tools/tools/cfg/bin_to_dot.cpp @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "bin_to_dot.h" +#include "tools/cfg/bin_to_dot.h" #include #include #include #include -#include "assembly_grammar.h" -#include "name_mapper.h" +#include "source/assembly_grammar.h" +#include "source/name_mapper.h" namespace { @@ -31,7 +31,7 @@ const char* kContinueStyle = "style=dotted"; // a SPIR-V module. class DotConverter { public: - DotConverter(libspirv::NameMapper name_mapper, std::iostream* out) + DotConverter(spvtools::NameMapper name_mapper, std::iostream* out) : name_mapper_(std::move(name_mapper)), out_(*out) {} // Emits the graph preamble. @@ -73,7 +73,7 @@ class DotConverter { uint32_t continue_target_ = 0; // An object for mapping Ids to names. - libspirv::NameMapper name_mapper_; + spvtools::NameMapper name_mapper_; // The output stream. std::ostream& out_; @@ -171,10 +171,10 @@ spv_result_t BinaryToDot(const spv_const_context context, const uint32_t* words, // Invalid arguments return error codes, but don't necessarily generate // diagnostics. These are programmer errors, not user errors. if (!diagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; - const libspirv::AssemblyGrammar grammar(context); + const spvtools::AssemblyGrammar grammar(context); if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE; - libspirv::FriendlyNameMapper friendly_mapper(context, words, num_words); + spvtools::FriendlyNameMapper friendly_mapper(context, words, num_words); DotConverter converter(friendly_mapper.GetNameMapper(), out); converter.Begin(); if (auto error = spvBinaryParse(context, &converter, words, num_words, diff --git a/3rdparty/spirv-tools/tools/cfg/bin_to_dot.h b/3rdparty/spirv-tools/tools/cfg/bin_to_dot.h index 1181b2521..4de2e07fa 100644 --- a/3rdparty/spirv-tools/tools/cfg/bin_to_dot.h +++ b/3rdparty/spirv-tools/tools/cfg/bin_to_dot.h @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef BIN_TO_DOT_H_ -#define BIN_TO_DOT_H_ +#ifndef TOOLS_CFG_BIN_TO_DOT_H_ +#define TOOLS_CFG_BIN_TO_DOT_H_ #include + #include "spirv-tools/libspirv.h" // Dumps the control flow graph for the given module to the output stream. @@ -24,4 +25,4 @@ spv_result_t BinaryToDot(const spv_const_context context, const uint32_t* words, size_t num_words, std::iostream* out, spv_diagnostic* diagnostic); -#endif // BIN_TO_DOT_H_ +#endif // TOOLS_CFG_BIN_TO_DOT_H_ diff --git a/3rdparty/spirv-tools/tools/cfg/cfg.cpp b/3rdparty/spirv-tools/tools/cfg/cfg.cpp index 2d0bcfae3..9e2c448ba 100644 --- a/3rdparty/spirv-tools/tools/cfg/cfg.cpp +++ b/3rdparty/spirv-tools/tools/cfg/cfg.cpp @@ -19,10 +19,9 @@ #include #include "spirv-tools/libspirv.h" +#include "tools/cfg/bin_to_dot.h" #include "tools/io.h" -#include "bin_to_dot.h" - // Prints a program usage message to stdout. static void print_usage(const char* argv0) { printf( diff --git a/3rdparty/spirv-tools/tools/comp/markv.cpp b/3rdparty/spirv-tools/tools/comp/markv.cpp index 216b83a4a..9a0a51808 100644 --- a/3rdparty/spirv-tools/tools/comp/markv.cpp +++ b/3rdparty/spirv-tools/tools/comp/markv.cpp @@ -19,13 +19,15 @@ #include #include #include +#include +#include #include -#include "markv_model_factory.h" #include "source/comp/markv.h" #include "source/spirv_target_env.h" #include "source/table.h" #include "spirv-tools/optimizer.hpp" +#include "tools/comp/markv_model_factory.h" #include "tools/io.h" namespace { @@ -138,7 +140,8 @@ int main(int argc, char** argv) { bool want_comments = false; bool validate_spirv_binary = false; - spvtools::MarkvModelType model_type = spvtools::kMarkvModelUnknown; + spvtools::comp::MarkvModelType model_type = + spvtools::comp::kMarkvModelUnknown; for (int argi = 2; argi < argc; ++argi) { if ('-' == argv[argi][0]) { @@ -167,17 +170,17 @@ int main(int argc, char** argv) { } else if (0 == strcmp(argv[argi], "--validate")) { validate_spirv_binary = true; } else if (0 == strcmp(argv[argi], "--model=shader_lite")) { - if (model_type != spvtools::kMarkvModelUnknown) + if (model_type != spvtools::comp::kMarkvModelUnknown) fprintf(stderr, "error: More than one model specified\n"); - model_type = spvtools::kMarkvModelShaderLite; + model_type = spvtools::comp::kMarkvModelShaderLite; } else if (0 == strcmp(argv[argi], "--model=shader_mid")) { - if (model_type != spvtools::kMarkvModelUnknown) + if (model_type != spvtools::comp::kMarkvModelUnknown) fprintf(stderr, "error: More than one model specified\n"); - model_type = spvtools::kMarkvModelShaderMid; + model_type = spvtools::comp::kMarkvModelShaderMid; } else if (0 == strcmp(argv[argi], "--model=shader_max")) { - if (model_type != spvtools::kMarkvModelUnknown) + if (model_type != spvtools::comp::kMarkvModelUnknown) fprintf(stderr, "error: More than one model specified\n"); - model_type = spvtools::kMarkvModelShaderMax; + model_type = spvtools::comp::kMarkvModelShaderMax; } else { print_usage(argv[0]); return 1; @@ -206,34 +209,34 @@ int main(int argc, char** argv) { } } - if (model_type == spvtools::kMarkvModelUnknown) - model_type = spvtools::kMarkvModelShaderLite; + if (model_type == spvtools::comp::kMarkvModelUnknown) + model_type = spvtools::comp::kMarkvModelShaderLite; - const auto no_comments = spvtools::MarkvLogConsumer(); + const auto no_comments = spvtools::comp::MarkvLogConsumer(); const auto output_to_stderr = [](const std::string& str) { std::cerr << str; }; ScopedContext ctx(kSpvEnv); - std::unique_ptr model = - spvtools::CreateMarkvModel(model_type); + std::unique_ptr model = + spvtools::comp::CreateMarkvModel(model_type); std::vector spirv; std::vector markv; - spvtools::MarkvCodecOptions options; + spvtools::comp::MarkvCodecOptions options; options.validate_spirv_binary = validate_spirv_binary; if (task == kEncode) { if (!ReadFile(input_filename, "rb", &spirv)) return 1; assert(!spirv.empty()); - if (SPV_SUCCESS != - spvtools::SpirvToMarkv(ctx.context, spirv, options, *model, - DiagnosticsMessageHandler, - want_comments ? output_to_stderr : no_comments, - spvtools::MarkvDebugConsumer(), &markv)) { + if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv( + ctx.context, spirv, options, *model, + DiagnosticsMessageHandler, + want_comments ? output_to_stderr : no_comments, + spvtools::comp::MarkvDebugConsumer(), &markv)) { std::cerr << "error: Failed to encode " << input_filename << " to MARK-V " << std::endl; return 1; @@ -245,11 +248,11 @@ int main(int argc, char** argv) { if (!ReadFile(input_filename, "rb", &markv)) return 1; assert(!markv.empty()); - if (SPV_SUCCESS != - spvtools::MarkvToSpirv(ctx.context, markv, options, *model, - DiagnosticsMessageHandler, - want_comments ? output_to_stderr : no_comments, - spvtools::MarkvDebugConsumer(), &spirv)) { + if (SPV_SUCCESS != spvtools::comp::MarkvToSpirv( + ctx.context, markv, options, *model, + DiagnosticsMessageHandler, + want_comments ? output_to_stderr : no_comments, + spvtools::comp::MarkvDebugConsumer(), &spirv)) { std::cerr << "error: Failed to decode " << input_filename << " to SPIR-V " << std::endl; return 1; @@ -285,11 +288,11 @@ int main(int argc, char** argv) { return true; }; - if (SPV_SUCCESS != - spvtools::SpirvToMarkv(ctx.context, spirv_before, options, *model, - DiagnosticsMessageHandler, - want_comments ? output_to_stderr : no_comments, - encoder_debug_consumer, &markv)) { + if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv( + ctx.context, spirv_before, options, *model, + DiagnosticsMessageHandler, + want_comments ? output_to_stderr : no_comments, + encoder_debug_consumer, &markv)) { std::cerr << "error: Failed to encode " << input_filename << " to MARK-V " << std::endl; return 1; @@ -355,7 +358,7 @@ int main(int argc, char** argv) { }; std::vector spirv_after; - const spv_result_t decoding_result = spvtools::MarkvToSpirv( + const spv_result_t decoding_result = spvtools::comp::MarkvToSpirv( ctx.context, markv, options, *model, DiagnosticsMessageHandler, want_comments ? output_to_stderr : no_comments, decoder_debug_consumer, &spirv_after); diff --git a/3rdparty/spirv-tools/tools/comp/markv_model_factory.cpp b/3rdparty/spirv-tools/tools/comp/markv_model_factory.cpp index ce190417c..863fcf558 100644 --- a/3rdparty/spirv-tools/tools/comp/markv_model_factory.cpp +++ b/3rdparty/spirv-tools/tools/comp/markv_model_factory.cpp @@ -12,25 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "markv_model_factory.h" +#include "tools/comp/markv_model_factory.h" -#include "markv_model_shader.h" +#include "source/util/make_unique.h" +#include "tools/comp/markv_model_shader.h" namespace spvtools { +namespace comp { std::unique_ptr CreateMarkvModel(MarkvModelType type) { std::unique_ptr model; switch (type) { case kMarkvModelShaderLite: { - model.reset(new MarkvModelShaderLite()); + model = MakeUnique(); break; } case kMarkvModelShaderMid: { - model.reset(new MarkvModelShaderMid()); + model = MakeUnique(); break; } case kMarkvModelShaderMax: { - model.reset(new MarkvModelShaderMax()); + model = MakeUnique(); break; } case kMarkvModelUnknown: { @@ -44,4 +46,5 @@ std::unique_ptr CreateMarkvModel(MarkvModelType type) { return model; } +} // namespace comp } // namespace spvtools diff --git a/3rdparty/spirv-tools/tools/comp/markv_model_factory.h b/3rdparty/spirv-tools/tools/comp/markv_model_factory.h index b0bf5e7a4..c13898b98 100644 --- a/3rdparty/spirv-tools/tools/comp/markv_model_factory.h +++ b/3rdparty/spirv-tools/tools/comp/markv_model_factory.h @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_COMP_MARKV_MODEL_FACTORY_H_ -#define SPIRV_TOOLS_COMP_MARKV_MODEL_FACTORY_H_ +#ifndef TOOLS_COMP_MARKV_MODEL_FACTORY_H_ +#define TOOLS_COMP_MARKV_MODEL_FACTORY_H_ #include #include "source/comp/markv_model.h" namespace spvtools { +namespace comp { enum MarkvModelType { kMarkvModelUnknown = 0, @@ -30,6 +31,7 @@ enum MarkvModelType { std::unique_ptr CreateMarkvModel(MarkvModelType type); +} // namespace comp } // namespace spvtools -#endif // SPIRV_TOOLS_COMP_MARKV_MODEL_FACTORY_H_ +#endif // TOOLS_COMP_MARKV_MODEL_FACTORY_H_ diff --git a/3rdparty/spirv-tools/tools/comp/markv_model_shader.cpp b/3rdparty/spirv-tools/tools/comp/markv_model_shader.cpp index f96c277d0..8e296cd8c 100644 --- a/3rdparty/spirv-tools/tools/comp/markv_model_shader.cpp +++ b/3rdparty/spirv-tools/tools/comp/markv_model_shader.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "markv_model_shader.h" +#include "tools/comp/markv_model_shader.h" #include #include @@ -21,10 +21,10 @@ #include #include -using spvutils::HuffmanCodec; +#include "source/util/make_unique.h" namespace spvtools { - +namespace comp { namespace { // Signals that the value is not in the coding scheme and a fallback method @@ -36,24 +36,7 @@ inline uint32_t CombineOpcodeAndNumOperands(uint32_t opcode, return opcode | (num_operands << 16); } -// The following file contains autogenerated statistical coding rules. -// Can be generated by running spirv-stats on representative corpus of shaders -// with flags: -// --codegen_opcode_and_num_operands_hist -// --codegen_opcode_and_num_operands_markov_huffman_codecs -// --codegen_literal_string_huffman_codecs -// --codegen_non_id_word_huffman_codecs -// --codegen_id_descriptor_huffman_codecs -// -// Example: -// find -type f -print0 | xargs -0 -s 2000000 -// ~/SPIRV-Tools/build/tools/spirv-stats -v -// --codegen_opcode_and_num_operands_hist -// --codegen_opcode_and_num_operands_markov_huffman_codecs -// --codegen_literal_string_huffman_codecs --codegen_non_id_word_huffman_codecs -// --codegen_id_descriptor_huffman_codecs -o -// ~/SPIRV-Tools/source/comp/markv_autogen.inc -#include "markv_model_shader_default_autogen.inc" +#include "tools/comp/markv_model_shader_default_autogen.inc" } // namespace @@ -61,8 +44,8 @@ MarkvModelShaderLite::MarkvModelShaderLite() { const uint16_t kVersionNumber = 1; SetModelVersion(kVersionNumber); - opcode_and_num_operands_huffman_codec_.reset( - new HuffmanCodec(GetOpcodeAndNumOperandsHist())); + opcode_and_num_operands_huffman_codec_ = + MakeUnique>(GetOpcodeAndNumOperandsHist()); id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor; } @@ -71,8 +54,8 @@ MarkvModelShaderMid::MarkvModelShaderMid() { const uint16_t kVersionNumber = 1; SetModelVersion(kVersionNumber); - opcode_and_num_operands_huffman_codec_.reset( - new HuffmanCodec(GetOpcodeAndNumOperandsHist())); + opcode_and_num_operands_huffman_codec_ = + MakeUnique>(GetOpcodeAndNumOperandsHist()); non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs(); id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs(); descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme(); @@ -85,8 +68,8 @@ MarkvModelShaderMax::MarkvModelShaderMax() { const uint16_t kVersionNumber = 1; SetModelVersion(kVersionNumber); - opcode_and_num_operands_huffman_codec_.reset( - new HuffmanCodec(GetOpcodeAndNumOperandsHist())); + opcode_and_num_operands_huffman_codec_ = + MakeUnique>(GetOpcodeAndNumOperandsHist()); opcode_and_num_operands_markov_huffman_codecs_ = GetOpcodeAndNumOperandsMarkovHuffmanCodecs(); non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs(); @@ -97,4 +80,5 @@ MarkvModelShaderMax::MarkvModelShaderMax() { id_fallback_strategy_ = IdFallbackStrategy::kRuleBased; } +} // namespace comp } // namespace spvtools diff --git a/3rdparty/spirv-tools/tools/comp/markv_model_shader.h b/3rdparty/spirv-tools/tools/comp/markv_model_shader.h index f45a86949..3a704571f 100644 --- a/3rdparty/spirv-tools/tools/comp/markv_model_shader.h +++ b/3rdparty/spirv-tools/tools/comp/markv_model_shader.h @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_MARKV_MODEL_SHADER_H_ -#define SPIRV_TOOLS_MARKV_MODEL_SHADER_H_ +#ifndef TOOLS_COMP_MARKV_MODEL_SHADER_H_ +#define TOOLS_COMP_MARKV_MODEL_SHADER_H_ #include "source/comp/markv_model.h" namespace spvtools { +namespace comp { // MARK-V shader compression model, which only uses fast and lightweight // algorithms, which do not require training and are not heavily dependent on @@ -40,6 +41,7 @@ class MarkvModelShaderMax : public MarkvModel { MarkvModelShaderMax(); }; +} // namespace comp } // namespace spvtools -#endif // SPIRV_TOOLS_MARKV_MODEL_SHADER_H_ +#endif // TOOLS_COMP_MARKV_MODEL_SHADER_H_ diff --git a/3rdparty/spirv-tools/tools/io.h b/3rdparty/spirv-tools/tools/io.h index 690b5c368..aaf8fcdd2 100644 --- a/3rdparty/spirv-tools/tools/io.h +++ b/3rdparty/spirv-tools/tools/io.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TOOLS_IO_H_ -#define LIBSPIRV_TOOLS_IO_H_ +#ifndef TOOLS_IO_H_ +#define TOOLS_IO_H_ #include #include @@ -21,13 +21,15 @@ // Appends the content from the file named as |filename| to |data|, assuming // each element in the file is of type |T|. The file is opened with the given -// |mode|. If |filename| is nullptr or "-", reads from the standard input. If -// any error occurs, writes error messages to standard error and returns false. +// |mode|. If |filename| is nullptr or "-", reads from the standard input, but +// reopened with the given mode. If any error occurs, writes error messages to +// standard error and returns false. template bool ReadFile(const char* filename, const char* mode, std::vector* data) { const int buf_size = 1024; const bool use_file = filename && strcmp("-", filename); - if (FILE* fp = (use_file ? fopen(filename, mode) : stdin)) { + if (FILE* fp = + (use_file ? fopen(filename, mode) : freopen(nullptr, mode, stdin))) { T buf[buf_size]; while (size_t len = fread(buf, sizeof(T), buf_size, fp)) { data->insert(data->end(), buf, buf + len); @@ -39,7 +41,10 @@ bool ReadFile(const char* filename, const char* mode, std::vector* data) { } } else { if (sizeof(T) != 1 && (ftell(fp) % sizeof(T))) { - fprintf(stderr, "error: corrupted word found in file '%s'\n", filename); + fprintf( + stderr, + "error: file size should be a multiple of %zd; file '%s' corrupt\n", + sizeof(T), filename); return false; } } @@ -74,4 +79,4 @@ bool WriteFile(const char* filename, const char* mode, const T* data, return true; } -#endif // LIBSPIRV_TOOLS_IO_H_ +#endif // TOOLS_IO_H_ diff --git a/3rdparty/spirv-tools/tools/opt/opt.cpp b/3rdparty/spirv-tools/tools/opt/opt.cpp index 110e4b968..fcd260e45 100644 --- a/3rdparty/spirv-tools/tools/opt/opt.cpp +++ b/3rdparty/spirv-tools/tools/opt/opt.cpp @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include #include @@ -20,15 +19,16 @@ #include #include #include +#include #include -#include "opt/set_spec_constant_default_value_pass.h" +#include "source/opt/log.h" +#include "source/opt/loop_peeling.h" +#include "source/opt/set_spec_constant_default_value_pass.h" +#include "source/spirv_validator_options.h" #include "spirv-tools/optimizer.hpp" - -#include "message.h" #include "tools/io.h" - -using namespace spvtools; +#include "tools/util/cli_consumer.h" namespace { @@ -40,6 +40,17 @@ struct OptStatus { int code; }; +// Message consumer for this tool. Used to emit diagnostics during +// initialization and setup. Note that |source| and |position| are irrelevant +// here because we are still not processing a SPIR-V input file. +void opt_diagnostic(spv_message_level_t level, const char* /*source*/, + const spv_position_t& /*positon*/, const char* message) { + if (level == SPV_MSG_ERROR) { + fprintf(stderr, "error: "); + } + fprintf(stderr, "%s\n", message); +} + std::string GetListOfPassesAsString(const spvtools::Optimizer& optimizer) { std::stringstream ss; for (const auto& name : optimizer.GetPassNames()) { @@ -93,6 +104,9 @@ Options (in lexicographical order): Cleanup the control flow graph. This will remove any unnecessary code from the CFG like unreachable code. Performed on entry point call tree functions and exported functions. + --combine-access-chains + Combines chained access chains to produce a single instruction + where possible. --compact-ids Remap result ids to a compact range starting from %%1 and without any gaps. @@ -124,16 +138,16 @@ Options (in lexicographical order): --eliminate-dead-functions Deletes functions that cannot be reached from entry points or exported functions. - --eliminate-dead-insert + --eliminate-dead-inserts Deletes unreferenced inserts into composites, most notably unused stores to vector components, that are not removed by aggressive dead code elimination. --eliminate-dead-variables Deletes module scope variables that are not referenced. --eliminate-insert-extract - Replace extract from a sequence of inserts with the - corresponding value. Performed only on entry point call tree - functions. + DEPRECATED. This pass has been replaced by the simplification + pass, and that pass will be run instead. + See --simplify-instructions. --eliminate-local-multi-store Replace stores and loads of function scope variables that are stored multiple times. Performed on variables referenceed only @@ -166,21 +180,40 @@ Options (in lexicographical order): early return in a loop. --legalize-hlsl Runs a series of optimizations that attempts to take SPIR-V - generated by and HLSL front-end and generate legal Vulkan SPIR-V. + generated by an HLSL front-end and generates legal Vulkan SPIR-V. The optimizations are: %s - Note this does not guarantee legal code. This option implies - --skip-validation. + Note this does not guarantee legal code. This option passes the + option --relax-logical-pointer to the validator. --local-redundancy-elimination Looks for instructions in the same basic block that compute the same value, and deletes the redundant ones. + --loop-fission + Splits any top level loops in which the register pressure has exceeded + a given threshold. The threshold must follow the use of this flag and + must be a positive integer value. + --loop-fusion + Identifies adjacent loops with the same lower and upper bound. + If this is legal, then merge the loops into a single loop. + Includes heuristics to ensure it does not increase number of + registers too much, while reducing the number of loads from + memory. Takes an additional positive integer argument to set + the maximum number of registers. --loop-unroll Fully unrolls loops marked with the Unroll flag --loop-unroll-partial Partially unrolls loops marked with the Unroll flag. Takes an additional non-0 integer argument to set the unroll factor, or how many times a loop body should be duplicated + --loop-peeling + Execute few first (respectively last) iterations before + (respectively after) the loop if it can elide some branches. + --loop-peeling-threshold + Takes a non-0 integer argument to set the loop peeling code size + growth threshold. The threshold prevents the loop peeling + from happening if the code size increase created by + the optimization is above the threshold. --merge-blocks Join two blocks into a single block if the second has the first as its only predecessor. Performed only on entry point @@ -252,9 +285,9 @@ Options (in lexicographical order): --private-to-local Change the scope of private variables that are used in a single function to that function. - --remove-duplicates - Removes duplicate types, decorations, capabilities and extension - instructions. + --reduce-load-size + Replaces loads of composite objects where not every component is + used by loads of just the elements that are used. --redundancy-elimination Looks for instructions in the same function that compute the same value, and deletes the redundant ones. @@ -262,6 +295,9 @@ Options (in lexicographical order): Allow store from one struct type to a different type with compatible layout and members. This option is forwarded to the validator. + --remove-duplicates + Removes duplicate types, decorations, capabilities and extension + instructions. --replace-invalid-opcode Replaces instructions whose opcode is valid for shader modules, but not for the current shader stage. To have an effect, all @@ -269,10 +305,12 @@ Options (in lexicographical order): --ssa-rewrite Replace loads and stores to function local variables with operations on SSA IDs. - --scalar-replacement + --scalar-replacement[=] Replace aggregate function scope variables that are only accessed via their elements with new function variables representing each - element. + element. is a limit on the size of the aggragates that will + be replaced. 0 means there is no limit. The default value is + 100. --set-spec-const-default-value ": ..." Set the default values of the specialization constants with : pairs specified in a double-quoted @@ -301,6 +339,10 @@ Options (in lexicographical order): prints CPU/WALL/USR/SYS time (and RSS if possible), but note that USR/SYS time are returned by getrusage() and can have a small error. + --vector-dce + This pass looks for components of vectors that are unused, and + removes them from the vector. Note this would still leave around + lots of dead code that a pass of ADCE will be able to remove. --workaround-1209 Rewrites instructions for which there are known driver bugs to avoid triggering those bugs. @@ -327,7 +369,8 @@ bool ReadFlagsFromFile(const char* oconfig_flag, std::vector* file_flags) { const char* fname = strchr(oconfig_flag, '='); if (fname == nullptr || fname[0] != '=') { - fprintf(stderr, "error: Invalid -Oconfig flag %s\n", oconfig_flag); + spvtools::Errorf(opt_diagnostic, nullptr, {}, "Invalid -Oconfig flag %s", + oconfig_flag); return false; } fname++; @@ -335,14 +378,25 @@ bool ReadFlagsFromFile(const char* oconfig_flag, std::ifstream input_file; input_file.open(fname); if (input_file.fail()) { - fprintf(stderr, "error: Could not open file '%s'\n", fname); + spvtools::Errorf(opt_diagnostic, nullptr, {}, "Could not open file '%s'", + fname); return false; } - while (!input_file.eof()) { - std::string flag; - input_file >> flag; - if (flag.length() > 0 && flag[0] != '#') { + std::string line; + while (std::getline(input_file, line)) { + // Ignore empty lines and lines starting with the comment marker '#'. + if (line.length() == 0 || line[0] == '#') { + continue; + } + + // Tokenize the line. Add all found tokens to the list of found flags. This + // mimics the way the shell will parse whitespace on the command line. NOTE: + // This does not support quoting and it is not intended to. + std::istringstream iss(line); + while (!iss.eof()) { + std::string flag; + iss >> flag; file_flags->push_back(flag); } } @@ -350,9 +404,10 @@ bool ReadFlagsFromFile(const char* oconfig_flag, return true; } -OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer, - const char** in_file, const char** out_file, - spv_validator_options options, bool* skip_validator); +OptStatus ParseFlags(int argc, const char** argv, + spvtools::Optimizer* optimizer, const char** in_file, + const char** out_file, spvtools::ValidatorOptions* options, + bool* skip_validator); // Parses and handles the -Oconfig flag. |prog_name| contains the name of // the spirv-opt binary (used to build a new argv vector for the recursive @@ -361,15 +416,15 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer, // // This returns the same OptStatus instance returned by ParseFlags. OptStatus ParseOconfigFlag(const char* prog_name, const char* opt_flag, - Optimizer* optimizer, const char** in_file, + spvtools::Optimizer* optimizer, const char** in_file, const char** out_file) { std::vector flags; flags.push_back(prog_name); std::vector file_flags; if (!ReadFlagsFromFile(opt_flag, &file_flags)) { - fprintf(stderr, - "error: Could not read optimizer flags from configuration file\n"); + spvtools::Error(opt_diagnostic, nullptr, {}, + "Could not read optimizer flags from configuration file"); return {OPT_STOP, 1}; } flags.insert(flags.end(), file_flags.begin(), file_flags.end()); @@ -377,9 +432,9 @@ OptStatus ParseOconfigFlag(const char* prog_name, const char* opt_flag, const char** new_argv = new const char*[flags.size()]; for (size_t i = 0; i < flags.size(); i++) { if (flags[i].find("-Oconfig=") != std::string::npos) { - fprintf(stderr, - "error: Flag -Oconfig= may not be used inside the configuration " - "file\n"); + spvtools::Error( + opt_diagnostic, nullptr, {}, + "Flag -Oconfig= may not be used inside the configuration file"); return {OPT_STOP, 1}; } new_argv[i] = flags[i].c_str(); @@ -390,37 +445,65 @@ OptStatus ParseOconfigFlag(const char* prog_name, const char* opt_flag, in_file, out_file, nullptr, &skip_validator); } -OptStatus ParseLoopUnrollPartialArg(int argc, const char** argv, int argi, - Optimizer* optimizer) { - if (argi < argc) { - int factor = atoi(argv[argi]); - if (factor != 0) { - optimizer->RegisterPass(CreateLoopUnrollPass(false, factor)); - return {OPT_CONTINUE, 0}; +// Canonicalize the flag in |argv[argi]| of the form '--pass arg' into +// '--pass=arg'. The optimizer only accepts arguments to pass names that use the +// form '--pass_name=arg'. Since spirv-opt also accepts the other form, this +// function makes the necessary conversion. +// +// Pass flags that require additional arguments should be handled here. Note +// that additional arguments should be given as a single string. If the flag +// requires more than one argument, the pass creator in +// Optimizer::GetPassFromFlag() should parse it accordingly (e.g., see the +// handler for --set-spec-const-default-value). +// +// If the argument requests one of the passes that need an additional argument, +// |argi| is modified to point past the current argument, and the string +// "argv[argi]=argv[argi + 1]" is returned. Otherwise, |argi| is unmodified and +// the string "|argv[argi]|" is returned. +std::string CanonicalizeFlag(const char** argv, int argc, int* argi) { + const char* cur_arg = argv[*argi]; + const char* next_arg = (*argi + 1 < argc) ? argv[*argi + 1] : nullptr; + std::ostringstream canonical_arg; + canonical_arg << cur_arg; + + // NOTE: DO NOT ADD NEW FLAGS HERE. + // + // These flags are supported for backwards compatibility. When adding new + // passes that need extra arguments in its command-line flag, please make them + // use the syntax "--pass_name[=pass_arg]. + if (0 == strcmp(cur_arg, "--set-spec-const-default-value") || + 0 == strcmp(cur_arg, "--loop-fission") || + 0 == strcmp(cur_arg, "--loop-fusion") || + 0 == strcmp(cur_arg, "--loop-unroll-partial") || + 0 == strcmp(cur_arg, "--loop-peeling-threshold")) { + if (next_arg) { + canonical_arg << "=" << next_arg; + ++(*argi); } } - fprintf(stderr, - "error: --loop-unroll-partial must be followed by a non-0 " - "integer\n"); - return {OPT_STOP, 1}; + + return canonical_arg.str(); } -// Parses command-line flags. |argc| contains the number of command-line flags. -// |argv| points to an array of strings holding the flags. |optimizer| is the -// Optimizer instance used to optimize the program. +// the number of command-line flags. |argv| points to an array of strings +// holding the flags. |optimizer| is the Optimizer instance used to optimize the +// program. // // On return, this function stores the name of the input program in |in_file|. // The name of the output file in |out_file|. The return value indicates whether // optimization should continue and a status code indicating an error or // success. -OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer, - const char** in_file, const char** out_file, - spv_validator_options options, bool* skip_validator) { +OptStatus ParseFlags(int argc, const char** argv, + spvtools::Optimizer* optimizer, const char** in_file, + const char** out_file, spvtools::ValidatorOptions* options, + bool* skip_validator) { + std::vector pass_flags; for (int argi = 1; argi < argc; ++argi) { const char* cur_arg = argv[argi]; if ('-' == cur_arg[0]) { if (0 == strcmp(cur_arg, "--version")) { - printf("%s\n", spvSoftwareVersionDetailsString()); + spvtools::Logf(opt_diagnostic, SPV_MSG_INFO, nullptr, {}, "%s\n", + spvSoftwareVersionDetailsString()); return {OPT_STOP, 0}; } else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { PrintUsage(argv[0]); @@ -432,158 +515,55 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer, PrintUsage(argv[0]); return {OPT_STOP, 1}; } - } else if (0 == strcmp(cur_arg, "--strip-debug")) { - optimizer->RegisterPass(CreateStripDebugInfoPass()); - } else if (0 == strcmp(cur_arg, "--strip-reflect")) { - optimizer->RegisterPass(CreateStripReflectInfoPass()); - } else if (0 == strcmp(cur_arg, "--set-spec-const-default-value")) { - if (++argi < argc) { - auto spec_ids_vals = - opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString( - argv[argi]); - if (!spec_ids_vals) { - fprintf(stderr, - "error: Invalid argument for " - "--set-spec-const-default-value: %s\n", - argv[argi]); - return {OPT_STOP, 1}; - } - optimizer->RegisterPass( - CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals))); + } else if ('\0' == cur_arg[1]) { + // Setting a filename of "-" to indicate stdin. + if (!*in_file) { + *in_file = cur_arg; } else { - fprintf( - stderr, - "error: Expected a string of : pairs."); + spvtools::Error(opt_diagnostic, nullptr, {}, + "More than one input file specified"); return {OPT_STOP, 1}; } - } else if (0 == strcmp(cur_arg, "--if-conversion")) { - optimizer->RegisterPass(CreateIfConversionPass()); - } else if (0 == strcmp(cur_arg, "--freeze-spec-const")) { - optimizer->RegisterPass(CreateFreezeSpecConstantValuePass()); - } else if (0 == strcmp(cur_arg, "--inline-entry-points-exhaustive")) { - optimizer->RegisterPass(CreateInlineExhaustivePass()); - } else if (0 == strcmp(cur_arg, "--inline-entry-points-opaque")) { - optimizer->RegisterPass(CreateInlineOpaquePass()); - } else if (0 == strcmp(cur_arg, "--convert-local-access-chains")) { - optimizer->RegisterPass(CreateLocalAccessChainConvertPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-dead-code-aggressive")) { - optimizer->RegisterPass(CreateAggressiveDCEPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-insert-extract")) { - optimizer->RegisterPass(CreateInsertExtractElimPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-local-single-block")) { - optimizer->RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-local-single-store")) { - optimizer->RegisterPass(CreateLocalSingleStoreElimPass()); - } else if (0 == strcmp(cur_arg, "--merge-blocks")) { - optimizer->RegisterPass(CreateBlockMergePass()); - } else if (0 == strcmp(cur_arg, "--merge-return")) { - optimizer->RegisterPass(CreateMergeReturnPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-dead-branches")) { - optimizer->RegisterPass(CreateDeadBranchElimPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-dead-functions")) { - optimizer->RegisterPass(CreateEliminateDeadFunctionsPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-local-multi-store")) { - optimizer->RegisterPass(CreateLocalMultiStoreElimPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-common-uniform")) { - optimizer->RegisterPass(CreateCommonUniformElimPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-dead-const")) { - optimizer->RegisterPass(CreateEliminateDeadConstantPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-dead-inserts")) { - optimizer->RegisterPass(CreateDeadInsertElimPass()); - } else if (0 == strcmp(cur_arg, "--eliminate-dead-variables")) { - optimizer->RegisterPass(CreateDeadVariableEliminationPass()); - } else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) { - optimizer->RegisterPass(CreateFoldSpecConstantOpAndCompositePass()); - } else if (0 == strcmp(cur_arg, "--loop-unswitch")) { - optimizer->RegisterPass(CreateLoopUnswitchPass()); - } else if (0 == strcmp(cur_arg, "--scalar-replacement")) { - optimizer->RegisterPass(CreateScalarReplacementPass()); - } else if (0 == strcmp(cur_arg, "--strength-reduction")) { - optimizer->RegisterPass(CreateStrengthReductionPass()); - } else if (0 == strcmp(cur_arg, "--unify-const")) { - optimizer->RegisterPass(CreateUnifyConstantPass()); - } else if (0 == strcmp(cur_arg, "--flatten-decorations")) { - optimizer->RegisterPass(CreateFlattenDecorationPass()); - } else if (0 == strcmp(cur_arg, "--compact-ids")) { - optimizer->RegisterPass(CreateCompactIdsPass()); - } else if (0 == strcmp(cur_arg, "--cfg-cleanup")) { - optimizer->RegisterPass(CreateCFGCleanupPass()); - } else if (0 == strcmp(cur_arg, "--local-redundancy-elimination")) { - optimizer->RegisterPass(CreateLocalRedundancyEliminationPass()); - } else if (0 == strcmp(cur_arg, "--loop-invariant-code-motion")) { - optimizer->RegisterPass(CreateLoopInvariantCodeMotionPass()); - } else if (0 == strcmp(cur_arg, "--redundancy-elimination")) { - optimizer->RegisterPass(CreateRedundancyEliminationPass()); - } else if (0 == strcmp(cur_arg, "--private-to-local")) { - optimizer->RegisterPass(CreatePrivateToLocalPass()); - } else if (0 == strcmp(cur_arg, "--remove-duplicates")) { - optimizer->RegisterPass(CreateRemoveDuplicatesPass()); - } else if (0 == strcmp(cur_arg, "--workaround-1209")) { - optimizer->RegisterPass(CreateWorkaround1209Pass()); - } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { - options->relax_struct_store = true; - } else if (0 == strcmp(cur_arg, "--replace-invalid-opcode")) { - optimizer->RegisterPass(CreateReplaceInvalidOpcodePass()); - } else if (0 == strcmp(cur_arg, "--simplify-instructions")) { - optimizer->RegisterPass(CreateSimplificationPass()); - } else if (0 == strcmp(cur_arg, "--ssa-rewrite")) { - optimizer->RegisterPass(CreateSSARewritePass()); - } else if (0 == strcmp(cur_arg, "--copy-propagate-arrays")) { - optimizer->RegisterPass(CreateCopyPropagateArraysPass()); - } else if (0 == strcmp(cur_arg, "--loop-unroll")) { - optimizer->RegisterPass(CreateLoopUnrollPass(true)); - } else if (0 == strcmp(cur_arg, "--loop-unroll-partial")) { - OptStatus status = - ParseLoopUnrollPartialArg(argc, argv, ++argi, optimizer); - if (status.action != OPT_CONTINUE) { - return status; - } - } else if (0 == strcmp(cur_arg, "--skip-validation")) { - *skip_validator = true; - } else if (0 == strcmp(cur_arg, "-O")) { - optimizer->RegisterPerformancePasses(); - } else if (0 == strcmp(cur_arg, "-Os")) { - optimizer->RegisterSizePasses(); - } else if (0 == strcmp(cur_arg, "--legalize-hlsl")) { - *skip_validator = true; - optimizer->RegisterLegalizationPasses(); } else if (0 == strncmp(cur_arg, "-Oconfig=", sizeof("-Oconfig=") - 1)) { OptStatus status = ParseOconfigFlag(argv[0], cur_arg, optimizer, in_file, out_file); if (status.action != OPT_CONTINUE) { return status; } - } else if (0 == strcmp(cur_arg, "--ccp")) { - optimizer->RegisterPass(CreateCCPPass()); + } else if (0 == strcmp(cur_arg, "--skip-validation")) { + *skip_validator = true; } else if (0 == strcmp(cur_arg, "--print-all")) { optimizer->SetPrintAll(&std::cerr); } else if (0 == strcmp(cur_arg, "--time-report")) { optimizer->SetTimeReport(&std::cerr); - } else if ('\0' == cur_arg[1]) { - // Setting a filename of "-" to indicate stdin. - if (!*in_file) { - *in_file = cur_arg; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return {OPT_STOP, 1}; - } + } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { + options->SetRelaxStructStore(true); } else { - fprintf( - stderr, - "error: Unknown flag '%s'. Use --help for a list of valid flags\n", - cur_arg); - return {OPT_STOP, 1}; + // Some passes used to accept the form '--pass arg', canonicalize them + // to '--pass=arg'. + pass_flags.push_back(CanonicalizeFlag(argv, argc, &argi)); + + // If we were requested to legalize SPIR-V generated from the HLSL + // front-end, skip validation. + if (0 == strcmp(cur_arg, "--legalize-hlsl")) { + options->SetRelaxLogicalPointer(true); + } } } else { if (!*in_file) { *in_file = cur_arg; } else { - fprintf(stderr, "error: More than one input file specified\n"); + spvtools::Error(opt_diagnostic, nullptr, {}, + "More than one input file specified"); return {OPT_STOP, 1}; } } } + if (!optimizer->RegisterPassesFromFlags(pass_flags)) { + return {OPT_STOP, 1}; + } + return {OPT_CONTINUE, 0}; } @@ -595,25 +575,20 @@ int main(int argc, const char** argv) { bool skip_validator = false; spv_target_env target_env = kDefaultEnvironment; - spv_validator_options options = spvValidatorOptionsCreate(); + spvtools::ValidatorOptions options; spvtools::Optimizer optimizer(target_env); - optimizer.SetMessageConsumer([](spv_message_level_t level, const char* source, - const spv_position_t& position, - const char* message) { - std::cerr << StringifyMessage(level, source, position, message) - << std::endl; - }); + optimizer.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); OptStatus status = ParseFlags(argc, argv, &optimizer, &in_file, &out_file, - options, &skip_validator); + &options, &skip_validator); if (status.action == OPT_STOP) { return status.code; } if (out_file == nullptr) { - fprintf(stderr, "error: -o required\n"); + spvtools::Error(opt_diagnostic, nullptr, {}, "-o required"); return 1; } @@ -622,28 +597,10 @@ int main(int argc, const char** argv) { return 1; } - if (!skip_validator) { - // Let's do validation first. - spv_context context = spvContextCreate(target_env); - spv_diagnostic diagnostic = nullptr; - spv_const_binary_t binary_struct = {binary.data(), binary.size()}; - spv_result_t error = - spvValidateWithOptions(context, options, &binary_struct, &diagnostic); - if (error) { - spvDiagnosticPrint(diagnostic); - spvDiagnosticDestroy(diagnostic); - spvValidatorOptionsDestroy(options); - spvContextDestroy(context); - return error; - } - spvDiagnosticDestroy(diagnostic); - spvValidatorOptionsDestroy(options); - spvContextDestroy(context); - } - // By using the same vector as input and output, we save time in the case // that there was no change. - bool ok = optimizer.Run(binary.data(), binary.size(), &binary); + bool ok = optimizer.Run(binary.data(), binary.size(), &binary, options, + skip_validator); if (!WriteFile(out_file, "wb", binary.data(), binary.size())) { return 1; diff --git a/3rdparty/spirv-tools/tools/stats/spirv_stats.cpp b/3rdparty/spirv-tools/tools/stats/spirv_stats.cpp new file mode 100644 index 000000000..7751c6402 --- /dev/null +++ b/3rdparty/spirv-tools/tools/stats/spirv_stats.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/stats/spirv_stats.h" + +#include + +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/id_descriptor.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/val/instruction.h" +#include "source/val/validate.h" +#include "source/val/validation_state.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace stats { +namespace { + +// Helper class for stats aggregation. Receives as in/out parameter. +// Constructs ValidationState and updates it by running validator for each +// instruction. +class StatsAggregator { + public: + StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state) + : stats_(in_out_stats), vstate_(state) {} + + // Processes the instructions to collect stats. + void aggregate() { + const auto& instructions = vstate_->ordered_instructions(); + + ++stats_->version_hist[vstate_->version()]; + ++stats_->generator_hist[vstate_->generator()]; + + for (size_t i = 0; i < instructions.size(); ++i) { + const auto& inst = instructions[i]; + + ProcessOpcode(&inst, i); + ProcessCapability(&inst); + ProcessExtension(&inst); + ProcessConstant(&inst); + } + } + + // Collects OpCapability statistics. + void ProcessCapability(const val::Instruction* inst) { + if (inst->opcode() != SpvOpCapability) return; + const uint32_t capability = inst->word(inst->operands()[0].offset); + ++stats_->capability_hist[capability]; + } + + // Collects OpExtension statistics. + void ProcessExtension(const val::Instruction* inst) { + if (inst->opcode() != SpvOpExtension) return; + const std::string extension = GetExtensionString(&inst->c_inst()); + ++stats_->extension_hist[extension]; + } + + // Collects OpCode statistics. + void ProcessOpcode(const val::Instruction* inst, size_t idx) { + const SpvOp opcode = inst->opcode(); + ++stats_->opcode_hist[opcode]; + + if (idx == 0) return; + + --idx; + + const auto& instructions = vstate_->ordered_instructions(); + + auto step_it = stats_->opcode_markov_hist.begin(); + for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) { + auto& hist = (*step_it)[instructions[idx].opcode()]; + ++hist[opcode]; + + if (idx == 0) break; + } + } + + // Collects OpConstant statistics. + void ProcessConstant(const val::Instruction* inst) { + if (inst->opcode() != SpvOpConstant) return; + + const uint32_t type_id = inst->GetOperandAs(0); + const auto type_decl_it = vstate_->all_definitions().find(type_id); + assert(type_decl_it != vstate_->all_definitions().end()); + + const val::Instruction& type_decl_inst = *type_decl_it->second; + const SpvOp type_op = type_decl_inst.opcode(); + if (type_op == SpvOpTypeInt) { + const uint32_t bit_width = type_decl_inst.GetOperandAs(1); + const uint32_t is_signed = type_decl_inst.GetOperandAs(2); + assert(is_signed == 0 || is_signed == 1); + if (bit_width == 16) { + if (is_signed) + ++stats_->s16_constant_hist[inst->GetOperandAs(2)]; + else + ++stats_->u16_constant_hist[inst->GetOperandAs(2)]; + } else if (bit_width == 32) { + if (is_signed) + ++stats_->s32_constant_hist[inst->GetOperandAs(2)]; + else + ++stats_->u32_constant_hist[inst->GetOperandAs(2)]; + } else if (bit_width == 64) { + if (is_signed) + ++stats_->s64_constant_hist[inst->GetOperandAs(2)]; + else + ++stats_->u64_constant_hist[inst->GetOperandAs(2)]; + } else { + assert(false && "TypeInt bit width is not 16, 32 or 64"); + } + } else if (type_op == SpvOpTypeFloat) { + const uint32_t bit_width = type_decl_inst.GetOperandAs(1); + if (bit_width == 32) { + ++stats_->f32_constant_hist[inst->GetOperandAs(2)]; + } else if (bit_width == 64) { + ++stats_->f64_constant_hist[inst->GetOperandAs(2)]; + } else { + assert(bit_width == 16); + } + } + } + + private: + SpirvStats* stats_; + const val::ValidationState_t* vstate_; + IdDescriptorCollection id_descriptors_; +}; + +} // namespace + +spv_result_t AggregateStats(const spv_context_t& context, const uint32_t* words, + const size_t num_words, spv_diagnostic* pDiagnostic, + SpirvStats* stats) { + std::unique_ptr vstate; + spv_validator_options_t options; + spv_result_t result = ValidateBinaryAndKeepValidationState( + &context, &options, words, num_words, pDiagnostic, &vstate); + if (result != SPV_SUCCESS) return result; + + StatsAggregator stats_aggregator(stats, vstate.get()); + stats_aggregator.aggregate(); + return SPV_SUCCESS; +} + +} // namespace stats +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/spirv_stats.h b/3rdparty/spirv-tools/tools/stats/spirv_stats.h similarity index 62% rename from 3rdparty/spirv-tools/source/spirv_stats.h rename to 3rdparty/spirv-tools/tools/stats/spirv_stats.h index cc6c23914..16e720fe3 100644 --- a/3rdparty/spirv-tools/source/spirv_stats.h +++ b/3rdparty/spirv-tools/tools/stats/spirv_stats.h @@ -12,17 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_SPIRV_STATS_H_ -#define LIBSPIRV_SPIRV_STATS_H_ +#ifndef TOOLS_STATS_SPIRV_STATS_H_ +#define TOOLS_STATS_SPIRV_STATS_H_ #include #include #include +#include #include #include "spirv-tools/libspirv.hpp" -namespace libspirv { +namespace spvtools { +namespace stats { struct SpirvStats { // Version histogram, version_word -> count. @@ -40,10 +42,6 @@ struct SpirvStats { // Opcode histogram, SpvOpXXX -> count. std::unordered_map opcode_hist; - // Histogram of words combining opcode and number of operands, - // opcode | (num_operands << 16) -> count. - std::unordered_map opcode_and_num_operands_hist; - // OpConstant u16 histogram, value -> count. std::unordered_map u16_constant_hist; @@ -68,42 +66,6 @@ struct SpirvStats { // OpConstant f64 histogram, value -> count. std::unordered_map f64_constant_hist; - // Enum histogram, operand type -> operand value -> count. - std::unordered_map> - enum_hist; - - // Histogram of all non-id single words. - // pair -> value -> count. - // This is a generalization of enum_hist, also includes literal integers and - // masks. - std::map, std::map> - operand_slot_non_id_words_hist; - - // Historgam of descriptors generated by IdDescriptorCollection. - // Descriptor -> count. - std::unordered_map id_descriptor_hist; - - // Debut labels for id descriptors, descriptor -> label. - std::unordered_map id_descriptor_labels; - - // Historgam of descriptors generated by IdDescriptorCollection for every - // operand slot. pair -> descriptor -> count. - std::map, std::map> - operand_slot_id_descriptor_hist; - - // Histogram of literal strings, sharded by opcodes, opcode -> string -> - // count. - // This is suboptimal if an opcode has multiple literal string operands, - // as it wouldn't differentiate between operands. - std::unordered_map> - literal_strings_hist; - - // Markov chain histograms: - // opcode -> next(opcode | (num_operands << 16)) -> count. - // See also opcode_and_num_operands_hist, which collects global statistics. - std::unordered_map> - opcode_and_num_operands_markov_hist; - // Used to collect statistics on opcodes triggering other opcodes. // Container scheme: gap between instructions -> cue opcode -> later opcode // -> count. @@ -125,6 +87,7 @@ spv_result_t AggregateStats(const spv_context_t& context, const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, SpirvStats* stats); -} // namespace libspirv +} // namespace stats +} // namespace spvtools -#endif // LIBSPIRV_SPIRV_STATS_H_ +#endif // TOOLS_STATS_SPIRV_STATS_H_ diff --git a/3rdparty/spirv-tools/tools/stats/stats.cpp b/3rdparty/spirv-tools/tools/stats/stats.cpp index 2e525fabb..256ec1e1b 100644 --- a/3rdparty/spirv-tools/tools/stats/stats.cpp +++ b/3rdparty/spirv-tools/tools/stats/stats.cpp @@ -17,14 +17,13 @@ #include #include #include +#include -#include "source/spirv_stats.h" #include "source/table.h" #include "spirv-tools/libspirv.h" -#include "stats_analyzer.h" #include "tools/io.h" - -using libspirv::SpirvStats; +#include "tools/stats/spirv_stats.h" +#include "tools/stats/stats_analyzer.h" namespace { @@ -49,35 +48,6 @@ Options: -v, --verbose Print additional info to stderr. - - --codegen_opcode_hist - Output generated C++ code for opcode histogram. - This flag disables non-C++ output. - - --codegen_opcode_and_num_operands_hist - Output generated C++ code for opcode_and_num_operands - histogram. - This flag disables non-C++ output. - - --codegen_opcode_and_num_operands_markov_huffman_codecs - Output generated C++ code for Huffman codecs of - opcode_and_num_operands Markov chain. - This flag disables non-C++ output. - - --codegen_literal_string_huffman_codecs - Output generated C++ code for Huffman codecs for - literal strings. - This flag disables non-C++ output. - - --codegen_non_id_word_huffman_codecs - Output generated C++ code for Huffman codecs for - single-word non-id slots. - This flag disables non-C++ output. - - --codegen_id_descriptor_huffman_codecs - Output generated C++ code for Huffman codecs for - common id descriptors. - This flag disables non-C++ output. )", argv0, argv0, argv0); } @@ -111,13 +81,6 @@ int main(int argc, char** argv) { bool expect_output_path = false; bool verbose = false; - bool export_text = true; - bool codegen_opcode_hist = false; - bool codegen_opcode_and_num_operands_hist = false; - bool codegen_opcode_and_num_operands_markov_huffman_codecs = false; - bool codegen_literal_string_huffman_codecs = false; - bool codegen_non_id_word_huffman_codecs = false; - bool codegen_id_descriptor_huffman_codecs = false; std::vector paths; const char* output_path = nullptr; @@ -129,29 +92,6 @@ int main(int argc, char** argv) { PrintUsage(argv[0]); continue_processing = false; return_code = 0; - } else if (0 == strcmp(cur_arg, "--codegen_opcode_hist")) { - codegen_opcode_hist = true; - export_text = false; - } else if (0 == - strcmp(cur_arg, "--codegen_opcode_and_num_operands_hist")) { - codegen_opcode_and_num_operands_hist = true; - export_text = false; - } else if (strcmp( - "--codegen_opcode_and_num_operands_markov_huffman_codecs", - cur_arg) == 0) { - codegen_opcode_and_num_operands_markov_huffman_codecs = true; - export_text = false; - } else if (0 == - strcmp(cur_arg, "--codegen_literal_string_huffman_codecs")) { - codegen_literal_string_huffman_codecs = true; - export_text = false; - } else if (0 == strcmp(cur_arg, "--codegen_non_id_word_huffman_codecs")) { - codegen_non_id_word_huffman_codecs = true; - export_text = false; - } else if (0 == - strcmp(cur_arg, "--codegen_id_descriptor_huffman_codecs")) { - codegen_id_descriptor_huffman_codecs = true; - export_text = false; } else if (0 == strcmp(cur_arg, "--verbose") || 0 == strcmp(cur_arg, "-v")) { verbose = true; @@ -181,9 +121,9 @@ int main(int argc, char** argv) { std::cerr << "Processing " << paths.size() << " files..." << std::endl; ScopedContext ctx(SPV_ENV_UNIVERSAL_1_1); - libspirv::SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler); + spvtools::SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler); - libspirv::SpirvStats stats; + spvtools::stats::SpirvStats stats; stats.opcode_markov_hist.resize(1); for (size_t index = 0; index < paths.size(); ++index) { @@ -197,15 +137,15 @@ int main(int argc, char** argv) { std::vector contents; if (!ReadFile(path, "rb", &contents)) return 1; - if (SPV_SUCCESS != libspirv::AggregateStats(*ctx.context, contents.data(), - contents.size(), nullptr, - &stats)) { + if (SPV_SUCCESS != + spvtools::stats::AggregateStats(*ctx.context, contents.data(), + contents.size(), nullptr, &stats)) { std::cerr << "error: Failed to aggregate stats for " << path << std::endl; return 1; } } - StatsAnalyzer analyzer(stats); + spvtools::stats::StatsAnalyzer analyzer(stats); std::ofstream fout; if (output_path) { @@ -217,57 +157,24 @@ int main(int argc, char** argv) { } std::ostream& out = fout.is_open() ? fout : std::cout; + out << std::endl; + analyzer.WriteVersion(out); + analyzer.WriteGenerator(out); - if (export_text) { - out << std::endl; - analyzer.WriteVersion(out); - analyzer.WriteGenerator(out); + out << std::endl; + analyzer.WriteCapability(out); - out << std::endl; - analyzer.WriteCapability(out); + out << std::endl; + analyzer.WriteExtension(out); - out << std::endl; - analyzer.WriteExtension(out); + out << std::endl; + analyzer.WriteOpcode(out); - out << std::endl; - analyzer.WriteOpcode(out); + out << std::endl; + analyzer.WriteOpcodeMarkov(out); - out << std::endl; - analyzer.WriteOpcodeMarkov(out); - - out << std::endl; - analyzer.WriteConstantLiterals(out); - } - - if (codegen_opcode_hist) { - out << std::endl; - analyzer.WriteCodegenOpcodeHist(out); - } - - if (codegen_opcode_and_num_operands_hist) { - out << std::endl; - analyzer.WriteCodegenOpcodeAndNumOperandsHist(out); - } - - if (codegen_opcode_and_num_operands_markov_huffman_codecs) { - out << std::endl; - analyzer.WriteCodegenOpcodeAndNumOperandsMarkovHuffmanCodecs(out); - } - - if (codegen_literal_string_huffman_codecs) { - out << std::endl; - analyzer.WriteCodegenLiteralStringHuffmanCodecs(out); - } - - if (codegen_non_id_word_huffman_codecs) { - out << std::endl; - analyzer.WriteCodegenNonIdWordHuffmanCodecs(out); - } - - if (codegen_id_descriptor_huffman_codecs) { - out << std::endl; - analyzer.WriteCodegenIdDescriptorHuffmanCodecs(out); - } + out << std::endl; + analyzer.WriteConstantLiterals(out); return 0; } diff --git a/3rdparty/spirv-tools/tools/stats/stats_analyzer.cpp b/3rdparty/spirv-tools/tools/stats/stats_analyzer.cpp index 7ce56c965..6d4cabbf6 100644 --- a/3rdparty/spirv-tools/tools/stats/stats_analyzer.cpp +++ b/3rdparty/spirv-tools/tools/stats/stats_analyzer.cpp @@ -12,357 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stats_analyzer.h" +#include "tools/stats/stats_analyzer.h" #include #include #include #include +#include #include +#include #include #include +#include #include -#include "latest_version_spirv_header.h" #include "source/comp/markv_model.h" #include "source/enum_string_mapping.h" +#include "source/latest_version_spirv_header.h" #include "source/opcode.h" #include "source/operand.h" #include "source/spirv_constant.h" -#include "source/util/huffman_codec.h" - -using libspirv::SpirvStats; -using spvutils::HuffmanCodec; +namespace spvtools { +namespace stats { namespace { // Signals that the value is not in the coding scheme and a fallback method // needs to be used. const uint64_t kMarkvNoneOfTheAbove = - spvtools::MarkvModel::GetMarkvNoneOfTheAbove(); - -inline uint32_t CombineOpcodeAndNumOperands(uint32_t opcode, - uint32_t num_operands) { - return opcode | (num_operands << 16); -} - -// Returns all SPIR-V v1.2 opcodes. -std::vector GetAllOpcodes() { - return std::vector({ - SpvOpNop, - SpvOpUndef, - SpvOpSourceContinued, - SpvOpSource, - SpvOpSourceExtension, - SpvOpName, - SpvOpMemberName, - SpvOpString, - SpvOpLine, - SpvOpExtension, - SpvOpExtInstImport, - SpvOpExtInst, - SpvOpMemoryModel, - SpvOpEntryPoint, - SpvOpExecutionMode, - SpvOpCapability, - SpvOpTypeVoid, - SpvOpTypeBool, - SpvOpTypeInt, - SpvOpTypeFloat, - SpvOpTypeVector, - SpvOpTypeMatrix, - SpvOpTypeImage, - SpvOpTypeSampler, - SpvOpTypeSampledImage, - SpvOpTypeArray, - SpvOpTypeRuntimeArray, - SpvOpTypeStruct, - SpvOpTypeOpaque, - SpvOpTypePointer, - SpvOpTypeFunction, - SpvOpTypeEvent, - SpvOpTypeDeviceEvent, - SpvOpTypeReserveId, - SpvOpTypeQueue, - SpvOpTypePipe, - SpvOpTypeForwardPointer, - SpvOpConstantTrue, - SpvOpConstantFalse, - SpvOpConstant, - SpvOpConstantComposite, - SpvOpConstantSampler, - SpvOpConstantNull, - SpvOpSpecConstantTrue, - SpvOpSpecConstantFalse, - SpvOpSpecConstant, - SpvOpSpecConstantComposite, - SpvOpSpecConstantOp, - SpvOpFunction, - SpvOpFunctionParameter, - SpvOpFunctionEnd, - SpvOpFunctionCall, - SpvOpVariable, - SpvOpImageTexelPointer, - SpvOpLoad, - SpvOpStore, - SpvOpCopyMemory, - SpvOpCopyMemorySized, - SpvOpAccessChain, - SpvOpInBoundsAccessChain, - SpvOpPtrAccessChain, - SpvOpArrayLength, - SpvOpGenericPtrMemSemantics, - SpvOpInBoundsPtrAccessChain, - SpvOpDecorate, - SpvOpMemberDecorate, - SpvOpDecorationGroup, - SpvOpGroupDecorate, - SpvOpGroupMemberDecorate, - SpvOpVectorExtractDynamic, - SpvOpVectorInsertDynamic, - SpvOpVectorShuffle, - SpvOpCompositeConstruct, - SpvOpCompositeExtract, - SpvOpCompositeInsert, - SpvOpCopyObject, - SpvOpTranspose, - SpvOpSampledImage, - SpvOpImageSampleImplicitLod, - SpvOpImageSampleExplicitLod, - SpvOpImageSampleDrefImplicitLod, - SpvOpImageSampleDrefExplicitLod, - SpvOpImageSampleProjImplicitLod, - SpvOpImageSampleProjExplicitLod, - SpvOpImageSampleProjDrefImplicitLod, - SpvOpImageSampleProjDrefExplicitLod, - SpvOpImageFetch, - SpvOpImageGather, - SpvOpImageDrefGather, - SpvOpImageRead, - SpvOpImageWrite, - SpvOpImage, - SpvOpImageQueryFormat, - SpvOpImageQueryOrder, - SpvOpImageQuerySizeLod, - SpvOpImageQuerySize, - SpvOpImageQueryLod, - SpvOpImageQueryLevels, - SpvOpImageQuerySamples, - SpvOpConvertFToU, - SpvOpConvertFToS, - SpvOpConvertSToF, - SpvOpConvertUToF, - SpvOpUConvert, - SpvOpSConvert, - SpvOpFConvert, - SpvOpQuantizeToF16, - SpvOpConvertPtrToU, - SpvOpSatConvertSToU, - SpvOpSatConvertUToS, - SpvOpConvertUToPtr, - SpvOpPtrCastToGeneric, - SpvOpGenericCastToPtr, - SpvOpGenericCastToPtrExplicit, - SpvOpBitcast, - SpvOpSNegate, - SpvOpFNegate, - SpvOpIAdd, - SpvOpFAdd, - SpvOpISub, - SpvOpFSub, - SpvOpIMul, - SpvOpFMul, - SpvOpUDiv, - SpvOpSDiv, - SpvOpFDiv, - SpvOpUMod, - SpvOpSRem, - SpvOpSMod, - SpvOpFRem, - SpvOpFMod, - SpvOpVectorTimesScalar, - SpvOpMatrixTimesScalar, - SpvOpVectorTimesMatrix, - SpvOpMatrixTimesVector, - SpvOpMatrixTimesMatrix, - SpvOpOuterProduct, - SpvOpDot, - SpvOpIAddCarry, - SpvOpISubBorrow, - SpvOpUMulExtended, - SpvOpSMulExtended, - SpvOpAny, - SpvOpAll, - SpvOpIsNan, - SpvOpIsInf, - SpvOpIsFinite, - SpvOpIsNormal, - SpvOpSignBitSet, - SpvOpLessOrGreater, - SpvOpOrdered, - SpvOpUnordered, - SpvOpLogicalEqual, - SpvOpLogicalNotEqual, - SpvOpLogicalOr, - SpvOpLogicalAnd, - SpvOpLogicalNot, - SpvOpSelect, - SpvOpIEqual, - SpvOpINotEqual, - SpvOpUGreaterThan, - SpvOpSGreaterThan, - SpvOpUGreaterThanEqual, - SpvOpSGreaterThanEqual, - SpvOpULessThan, - SpvOpSLessThan, - SpvOpULessThanEqual, - SpvOpSLessThanEqual, - SpvOpFOrdEqual, - SpvOpFUnordEqual, - SpvOpFOrdNotEqual, - SpvOpFUnordNotEqual, - SpvOpFOrdLessThan, - SpvOpFUnordLessThan, - SpvOpFOrdGreaterThan, - SpvOpFUnordGreaterThan, - SpvOpFOrdLessThanEqual, - SpvOpFUnordLessThanEqual, - SpvOpFOrdGreaterThanEqual, - SpvOpFUnordGreaterThanEqual, - SpvOpShiftRightLogical, - SpvOpShiftRightArithmetic, - SpvOpShiftLeftLogical, - SpvOpBitwiseOr, - SpvOpBitwiseXor, - SpvOpBitwiseAnd, - SpvOpNot, - SpvOpBitFieldInsert, - SpvOpBitFieldSExtract, - SpvOpBitFieldUExtract, - SpvOpBitReverse, - SpvOpBitCount, - SpvOpDPdx, - SpvOpDPdy, - SpvOpFwidth, - SpvOpDPdxFine, - SpvOpDPdyFine, - SpvOpFwidthFine, - SpvOpDPdxCoarse, - SpvOpDPdyCoarse, - SpvOpFwidthCoarse, - SpvOpEmitVertex, - SpvOpEndPrimitive, - SpvOpEmitStreamVertex, - SpvOpEndStreamPrimitive, - SpvOpControlBarrier, - SpvOpMemoryBarrier, - SpvOpAtomicLoad, - SpvOpAtomicStore, - SpvOpAtomicExchange, - SpvOpAtomicCompareExchange, - SpvOpAtomicCompareExchangeWeak, - SpvOpAtomicIIncrement, - SpvOpAtomicIDecrement, - SpvOpAtomicIAdd, - SpvOpAtomicISub, - SpvOpAtomicSMin, - SpvOpAtomicUMin, - SpvOpAtomicSMax, - SpvOpAtomicUMax, - SpvOpAtomicAnd, - SpvOpAtomicOr, - SpvOpAtomicXor, - SpvOpPhi, - SpvOpLoopMerge, - SpvOpSelectionMerge, - SpvOpLabel, - SpvOpBranch, - SpvOpBranchConditional, - SpvOpSwitch, - SpvOpKill, - SpvOpReturn, - SpvOpReturnValue, - SpvOpUnreachable, - SpvOpLifetimeStart, - SpvOpLifetimeStop, - SpvOpGroupAsyncCopy, - SpvOpGroupWaitEvents, - SpvOpGroupAll, - SpvOpGroupAny, - SpvOpGroupBroadcast, - SpvOpGroupIAdd, - SpvOpGroupFAdd, - SpvOpGroupFMin, - SpvOpGroupUMin, - SpvOpGroupSMin, - SpvOpGroupFMax, - SpvOpGroupUMax, - SpvOpGroupSMax, - SpvOpReadPipe, - SpvOpWritePipe, - SpvOpReservedReadPipe, - SpvOpReservedWritePipe, - SpvOpReserveReadPipePackets, - SpvOpReserveWritePipePackets, - SpvOpCommitReadPipe, - SpvOpCommitWritePipe, - SpvOpIsValidReserveId, - SpvOpGetNumPipePackets, - SpvOpGetMaxPipePackets, - SpvOpGroupReserveReadPipePackets, - SpvOpGroupReserveWritePipePackets, - SpvOpGroupCommitReadPipe, - SpvOpGroupCommitWritePipe, - SpvOpEnqueueMarker, - SpvOpEnqueueKernel, - SpvOpGetKernelNDrangeSubGroupCount, - SpvOpGetKernelNDrangeMaxSubGroupSize, - SpvOpGetKernelWorkGroupSize, - SpvOpGetKernelPreferredWorkGroupSizeMultiple, - SpvOpRetainEvent, - SpvOpReleaseEvent, - SpvOpCreateUserEvent, - SpvOpIsValidEvent, - SpvOpSetUserEventStatus, - SpvOpCaptureEventProfilingInfo, - SpvOpGetDefaultQueue, - SpvOpBuildNDRange, - SpvOpImageSparseSampleImplicitLod, - SpvOpImageSparseSampleExplicitLod, - SpvOpImageSparseSampleDrefImplicitLod, - SpvOpImageSparseSampleDrefExplicitLod, - SpvOpImageSparseSampleProjImplicitLod, - SpvOpImageSparseSampleProjExplicitLod, - SpvOpImageSparseSampleProjDrefImplicitLod, - SpvOpImageSparseSampleProjDrefExplicitLod, - SpvOpImageSparseFetch, - SpvOpImageSparseGather, - SpvOpImageSparseDrefGather, - SpvOpImageSparseTexelsResident, - SpvOpNoLine, - SpvOpAtomicFlagTestAndSet, - SpvOpAtomicFlagClear, - SpvOpImageSparseRead, - SpvOpSizeOf, - SpvOpTypePipeStorage, - SpvOpConstantPipeStorage, - SpvOpCreatePipeFromPipeStorage, - SpvOpGetKernelLocalSizeForSubgroupCount, - SpvOpGetKernelMaxNumSubgroups, - SpvOpTypeNamedBarrier, - SpvOpNamedBarrierInitialize, - SpvOpMemoryNamedBarrier, - SpvOpModuleProcessed, - SpvOpExecutionModeId, - SpvOpDecorateId, - SpvOpSubgroupBallotKHR, - SpvOpSubgroupFirstInvocationKHR, - SpvOpSubgroupAllKHR, - SpvOpSubgroupAnyKHR, - SpvOpSubgroupAllEqualKHR, - SpvOpSubgroupReadInvocationKHR, - }); -} + comp::MarkvModel::GetMarkvNoneOfTheAbove(); std::string GetVersionString(uint32_t word) { std::stringstream ss; @@ -380,7 +58,7 @@ std::string GetOpcodeString(uint32_t word) { } std::string GetCapabilityString(uint32_t word) { - return libspirv::CapabilityToString(static_cast(word)); + return CapabilityToString(static_cast(word)); } template @@ -420,7 +98,7 @@ std::unordered_map GetPrevalence( // |label_from_key| is used to convert |Key| to label. template void WriteFreq(std::ostream& out, const std::unordered_map& freq, - std::string (*label_from_key)(Key), double threshold = 0.001) { + std::string (*label_from_key)(Key)) { std::vector> sorted_freq(freq.begin(), freq.end()); std::sort(sorted_freq.begin(), sorted_freq.end(), [](const std::pair& left, @@ -429,32 +107,12 @@ void WriteFreq(std::ostream& out, const std::unordered_map& freq, }); for (const auto& pair : sorted_freq) { - if (pair.second < threshold) break; + if (pair.second < 0.001) break; out << label_from_key(pair.first) << " " << pair.second * 100.0 << "%" << std::endl; } } -// Writes |hist| to |out| sorted by count in the following format: -// LABEL3 100 -// LABEL1 50 -// LABEL2 10 -// |label_from_key| is used to convert |Key| to label. -template -void WriteHist(std::ostream& out, const std::unordered_map& hist, - std::string (*label_from_key)(Key)) { - std::vector> sorted_hist(hist.begin(), hist.end()); - std::sort(sorted_hist.begin(), sorted_hist.end(), - [](const std::pair& left, - const std::pair& right) { - return left.second > right.second; - }); - - for (const auto& pair : sorted_hist) { - out << label_from_key(pair.first) << " " << pair.second << std::endl; - } -} - } // namespace StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) { @@ -573,302 +231,5 @@ void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) { } } -void StatsAnalyzer::WriteCodegenOpcodeHist(std::ostream& out) { - auto all_opcodes = GetAllOpcodes(); - - // uint64_t is used because kMarkvNoneOfTheAbove is outside of uint32_t range. - out << "std::map GetOpcodeHist() {\n" - << " return std::map({\n"; - - uint32_t total = 0; - for (const auto& kv : stats_.opcode_hist) { - total += kv.second; - } - - for (uint32_t opcode : all_opcodes) { - const auto it = stats_.opcode_hist.find(opcode); - const uint32_t count = it == stats_.opcode_hist.end() ? 0 : it->second; - const double kMaxValue = 1000.0; - uint32_t value = uint32_t(kMaxValue * double(count) / double(total)); - if (value == 0) value = 1; - out << " { SpvOp" << GetOpcodeString(opcode) << ", " << value << " },\n"; - } - - // Add kMarkvNoneOfTheAbove as a signal for unknown opcode. - out << " { kMarkvNoneOfTheAbove, " << 10 << " },\n"; - out << " });\n}\n"; -} - -void StatsAnalyzer::WriteCodegenOpcodeAndNumOperandsHist(std::ostream& out) { - out << "std::map GetOpcodeAndNumOperandsHist() {\n" - << " return std::map({\n"; - - uint32_t total = 0; - for (const auto& kv : stats_.opcode_and_num_operands_hist) { - total += kv.second; - } - - uint32_t left_out = 0; - - for (const auto& kv : stats_.opcode_and_num_operands_hist) { - const uint32_t count = kv.second; - const double kFrequentEnoughToAnalyze = 0.001; - const uint32_t opcode_and_num_operands = kv.first; - const uint32_t opcode = opcode_and_num_operands & 0xFFFF; - const uint32_t num_operands = opcode_and_num_operands >> 16; - - if (opcode == SpvOpTypeStruct || - double(count) / double(total) < kFrequentEnoughToAnalyze) { - left_out += count; - continue; - } - - out << " { CombineOpcodeAndNumOperands(SpvOp" - << spvOpcodeString(SpvOp(opcode)) << ", " << num_operands << "), " - << count << " },\n"; - } - - // Heuristic. - const uint32_t none_of_the_above = std::max(1, int(left_out + total * 0.01)); - out << " { kMarkvNoneOfTheAbove, " << none_of_the_above << " },\n"; - out << " });\n}\n"; -} - -void StatsAnalyzer::WriteCodegenOpcodeAndNumOperandsMarkovHuffmanCodecs( - std::ostream& out) { - out << "std::map>>\n" - << "GetOpcodeAndNumOperandsMarkovHuffmanCodecs() {\n" - << " std::map>> " - << "codecs;\n"; - - for (const auto& kv : stats_.opcode_and_num_operands_markov_hist) { - const uint32_t prev_opcode = kv.first; - const double kFrequentEnoughToAnalyze = 0.001; - if (opcode_freq_[prev_opcode] < kFrequentEnoughToAnalyze) continue; - - const std::unordered_map& hist = kv.second; - - uint32_t total = 0; - for (const auto& pair : hist) { - total += pair.second; - } - - uint32_t left_out = 0; - - std::map processed_hist; - for (const auto& pair : hist) { - const uint32_t opcode_and_num_operands = pair.first; - const uint32_t opcode = opcode_and_num_operands & 0xFFFF; - - if (opcode == SpvOpTypeStruct) continue; - - const uint32_t num_operands = opcode_and_num_operands >> 16; - const uint32_t count = pair.second; - const double posterior_freq = double(count) / double(total); - - if (opcode_freq_[opcode] < kFrequentEnoughToAnalyze && - posterior_freq < kFrequentEnoughToAnalyze) { - left_out += count; - continue; - } - processed_hist.emplace(CombineOpcodeAndNumOperands(opcode, num_operands), - count); - } - - // Heuristic. - processed_hist.emplace(kMarkvNoneOfTheAbove, - std::max(1, int(left_out + total * 0.01))); - - HuffmanCodec codec(processed_hist); - - out << " {\n"; - out << " std::unique_ptr> " - << "codec(new HuffmanCodec"; - out << codec.SerializeToText(4); - out << ");\n" << std::endl; - out << " codecs.emplace(SpvOp" << GetOpcodeString(prev_opcode) - << ", std::move(codec));\n"; - out << " }\n\n"; - } - - out << " return codecs;\n}\n"; -} - -void StatsAnalyzer::WriteCodegenLiteralStringHuffmanCodecs(std::ostream& out) { - out << "std::map>>\n" - << "GetLiteralStringHuffmanCodecs() {\n" - << " std::map>> " - << "codecs;\n"; - - for (const auto& kv : stats_.literal_strings_hist) { - const uint32_t opcode = kv.first; - - if (opcode == SpvOpName || opcode == SpvOpMemberName) continue; - - const double kOpcodeFrequentEnoughToAnalyze = 0.001; - if (opcode_freq_[opcode] < kOpcodeFrequentEnoughToAnalyze) continue; - - const std::unordered_map& hist = kv.second; - - uint32_t total = 0; - for (const auto& pair : hist) { - total += pair.second; - } - - uint32_t left_out = 0; - - std::map processed_hist; - for (const auto& pair : hist) { - const uint32_t count = pair.second; - const double freq = double(count) / double(total); - const double kStringFrequentEnoughToAnalyze = 0.001; - if (freq < kStringFrequentEnoughToAnalyze) { - left_out += count; - continue; - } - processed_hist.emplace(pair.first, count); - } - - // Heuristic. - processed_hist.emplace("kMarkvNoneOfTheAbove", - std::max(1, int(left_out + total * 0.01))); - - HuffmanCodec codec(processed_hist); - - out << " {\n"; - out << " std::unique_ptr> " - << "codec(new HuffmanCodec"; - out << codec.SerializeToText(4); - out << ");\n" << std::endl; - out << " codecs.emplace(SpvOp" << spvOpcodeString(SpvOp(opcode)) - << ", std::move(codec));\n"; - out << " }\n\n"; - } - - out << " return codecs;\n}\n"; -} - -void StatsAnalyzer::WriteCodegenNonIdWordHuffmanCodecs(std::ostream& out) { - out << "std::map, " - << "std::unique_ptr>>\n" - << "GetNonIdWordHuffmanCodecs() {\n" - << " std::map, " - << "std::unique_ptr>> codecs;\n"; - - for (const auto& kv : stats_.operand_slot_non_id_words_hist) { - const auto& opcode_and_index = kv.first; - const uint32_t opcode = opcode_and_index.first; - const uint32_t index = opcode_and_index.second; - - const double kOpcodeFrequentEnoughToAnalyze = 0.001; - if (opcode_freq_[opcode] < kOpcodeFrequentEnoughToAnalyze) continue; - - const std::map& hist = kv.second; - - uint32_t total = 0; - for (const auto& pair : hist) { - total += pair.second; - } - - uint32_t left_out = 0; - - std::map processed_hist; - for (const auto& pair : hist) { - const uint32_t word = pair.first; - const uint32_t count = pair.second; - const double freq = double(count) / double(total); - const double kWordFrequentEnoughToAnalyze = 0.003; - if (freq < kWordFrequentEnoughToAnalyze) { - left_out += count; - continue; - } - processed_hist.emplace(word, count); - } - - // Heuristic. - processed_hist.emplace(kMarkvNoneOfTheAbove, - std::max(1, int(left_out + total * 0.01))); - - HuffmanCodec codec(processed_hist); - - out << " {\n"; - out << " std::unique_ptr> " - << "codec(new HuffmanCodec"; - out << codec.SerializeToText(4); - out << ");\n" << std::endl; - out << " codecs.emplace(std::pair(SpvOp" - << spvOpcodeString(SpvOp(opcode)) << ", " << index - << "), std::move(codec));\n"; - out << " }\n\n"; - } - - out << " return codecs;\n}\n"; -} - -void StatsAnalyzer::WriteCodegenIdDescriptorHuffmanCodecs(std::ostream& out) { - out << "std::map, " - << "std::unique_ptr>>\n" - << "GetIdDescriptorHuffmanCodecs() {\n" - << " std::map, " - << "std::unique_ptr>> codecs;\n"; - - std::unordered_set descriptors_with_coding_scheme; - - for (const auto& kv : stats_.operand_slot_id_descriptor_hist) { - const auto& opcode_and_index = kv.first; - const uint32_t opcode = opcode_and_index.first; - const uint32_t index = opcode_and_index.second; - - const double kOpcodeFrequentEnoughToAnalyze = 0.003; - if (opcode_freq_[opcode] < kOpcodeFrequentEnoughToAnalyze) continue; - - const std::map& hist = kv.second; - - uint32_t total = 0; - for (const auto& pair : hist) { - total += pair.second; - } - - uint32_t left_out = 0; - - std::map processed_hist; - for (const auto& pair : hist) { - const uint32_t descriptor = pair.first; - const uint32_t count = pair.second; - const double freq = double(count) / double(total); - const double kDescriptorFrequentEnoughToAnalyze = 0.003; - if (freq < kDescriptorFrequentEnoughToAnalyze) { - left_out += count; - continue; - } - processed_hist.emplace(descriptor, count); - descriptors_with_coding_scheme.insert(descriptor); - } - - // Heuristic. - processed_hist.emplace(kMarkvNoneOfTheAbove, - std::max(1, int(left_out + total * 0.01))); - - HuffmanCodec codec(processed_hist); - - out << " {\n"; - out << " std::unique_ptr> " - << "codec(new HuffmanCodec"; - out << codec.SerializeToText(4); - out << ");\n" << std::endl; - out << " codecs.emplace(std::pair(SpvOp" - << spvOpcodeString(SpvOp(opcode)) << ", " << index - << "), std::move(codec));\n"; - out << " }\n\n"; - } - - out << " return codecs;\n}\n"; - - out << "\nstd::unordered_set GetDescriptorsWithCodingScheme() {\n" - << " std::unordered_set descriptors_with_coding_scheme = {\n"; - for (uint32_t descriptor : descriptors_with_coding_scheme) { - out << " " << descriptor << ",\n"; - } - out << " };\n"; - out << " return descriptors_with_coding_scheme;\n}\n"; -} +} // namespace stats +} // namespace spvtools diff --git a/3rdparty/spirv-tools/tools/stats/stats_analyzer.h b/3rdparty/spirv-tools/tools/stats/stats_analyzer.h index a3b91baef..f1c37bfaa 100644 --- a/3rdparty/spirv-tools/tools/stats/stats_analyzer.h +++ b/3rdparty/spirv-tools/tools/stats/stats_analyzer.h @@ -12,16 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_TOOLS_STATS_STATS_ANALYZER_H_ -#define LIBSPIRV_TOOLS_STATS_STATS_ANALYZER_H_ +#ifndef TOOLS_STATS_STATS_ANALYZER_H_ +#define TOOLS_STATS_STATS_ANALYZER_H_ +#include #include -#include "source/spirv_stats.h" +#include "tools/stats/spirv_stats.h" + +namespace spvtools { +namespace stats { class StatsAnalyzer { public: - explicit StatsAnalyzer(const libspirv::SpirvStats& stats); + explicit StatsAnalyzer(const SpirvStats& stats); // Writes respective histograms to |out|. void WriteVersion(std::ostream& out); @@ -36,37 +40,8 @@ class StatsAnalyzer { // level. void WriteOpcodeMarkov(std::ostream& out); - // Writes C++ code containing a function returning opcode histogram. - void WriteCodegenOpcodeHist(std::ostream& out); - - // Writes C++ code containing a function returning opcode_and_num_operands - // histogram. - void WriteCodegenOpcodeAndNumOperandsHist(std::ostream& out); - - // Writes C++ code containing a function returning a map of Huffman codecs - // for opcode_and_num_operands. Each Huffman codec is created for a specific - // previous opcode. - // TODO(atgoo@github.com) Write code which would contain pregenerated Huffman - // codecs, instead of code which would generate them every time. - void WriteCodegenOpcodeAndNumOperandsMarkovHuffmanCodecs(std::ostream& out); - - // Writes C++ code containing a function returning a map of Huffman codecs - // for literal strings. Each Huffman codec is created for a specific opcode. - // I.e. OpExtension and OpExtInstImport would use different codecs. - void WriteCodegenLiteralStringHuffmanCodecs(std::ostream& out); - - // Writes C++ code containing a function returning a map of Huffman codecs - // for single-word non-id operands. Each Huffman codec is created for a - // specific operand slot (opcode and operand number). - void WriteCodegenNonIdWordHuffmanCodecs(std::ostream& out); - - // Writes C++ code containing a function returning a map of Huffman codecs - // for common id descriptors. Each Huffman codec is created for a - // specific operand slot (opcode and operand number). - void WriteCodegenIdDescriptorHuffmanCodecs(std::ostream& out); - private: - const libspirv::SpirvStats& stats_; + const SpirvStats& stats_; uint32_t num_modules_; @@ -77,4 +52,7 @@ class StatsAnalyzer { std::unordered_map opcode_freq_; }; -#endif // LIBSPIRV_TOOLS_STATS_STATS_ANALYZER_H_ +} // namespace stats +} // namespace spvtools + +#endif // TOOLS_STATS_STATS_ANALYZER_H_ diff --git a/3rdparty/spirv-tools/source/message.cpp b/3rdparty/spirv-tools/tools/util/cli_consumer.cpp similarity index 51% rename from 3rdparty/spirv-tools/source/message.cpp rename to 3rdparty/spirv-tools/tools/util/cli_consumer.cpp index 030fa4e23..77db734e8 100644 --- a/3rdparty/spirv-tools/source/message.cpp +++ b/3rdparty/spirv-tools/tools/util/cli_consumer.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2016 Google Inc. +// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,43 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "message.h" +#include "tools/util/cli_consumer.h" -#include +#include namespace spvtools { +namespace utils { -std::string StringifyMessage(spv_message_level_t level, const char* source, - const spv_position_t& position, - const char* message) { - const char* level_string = nullptr; +void CLIMessageConsumer(spv_message_level_t level, const char*, + const spv_position_t& position, const char* message) { switch (level) { case SPV_MSG_FATAL: - level_string = "fatal"; - break; case SPV_MSG_INTERNAL_ERROR: - level_string = "internal error"; - break; case SPV_MSG_ERROR: - level_string = "error"; + std::cerr << "error: line " << position.index << ": " << message + << std::endl; break; case SPV_MSG_WARNING: - level_string = "warning"; + std::cout << "warning: line " << position.index << ": " << message + << std::endl; break; case SPV_MSG_INFO: - level_string = "info"; + std::cout << "info: line " << position.index << ": " << message + << std::endl; break; - case SPV_MSG_DEBUG: - level_string = "debug"; + default: break; } - std::ostringstream oss; - oss << level_string << ": "; - if (source) oss << source << ":"; - oss << position.line << ":" << position.column << ":"; - oss << position.index << ": "; - if (message) oss << message; - return oss.str(); } +} // namespace utils } // namespace spvtools diff --git a/3rdparty/spirv-tools/source/message.h b/3rdparty/spirv-tools/tools/util/cli_consumer.h similarity index 53% rename from 3rdparty/spirv-tools/source/message.h rename to 3rdparty/spirv-tools/tools/util/cli_consumer.h index 60f5d5632..ca3d91b95 100644 --- a/3rdparty/spirv-tools/source/message.h +++ b/3rdparty/spirv-tools/tools/util/cli_consumer.h @@ -1,4 +1,4 @@ -// Copyright (c) 2016 Google Inc. +// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,22 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SPIRV_TOOLS_MESSAGE_H_ -#define SPIRV_TOOLS_MESSAGE_H_ +#ifndef SOURCE_UTIL_CLI_CONSUMMER_H_ +#define SOURCE_UTIL_CLI_CONSUMMER_H_ -#include - -#include "spirv-tools/libspirv.h" +#include namespace spvtools { +namespace utils { -// A helper function to compose and return a string from the message in the -// following format: -// ": :::: " -std::string StringifyMessage(spv_message_level_t level, const char* source, - const spv_position_t& position, - const char* message); +// A message consumer that can be used by command line tools like spirv-opt and +// spirv-val to display messages. +void CLIMessageConsumer(spv_message_level_t level, const char*, + const spv_position_t& position, const char* message); +} // namespace utils } // namespace spvtools -#endif // SPIRV_TOOLS_MESSAGE_H_ +#endif // SOURCE_UTIL_CLI_CONSUMMER_H_ diff --git a/3rdparty/spirv-tools/tools/val/val.cpp b/3rdparty/spirv-tools/tools/val/val.cpp index be10aaf27..172dd121d 100644 --- a/3rdparty/spirv-tools/tools/val/val.cpp +++ b/3rdparty/spirv-tools/tools/val/val.cpp @@ -22,6 +22,7 @@ #include "source/spirv_validator_options.h" #include "spirv-tools/libspirv.hpp" #include "tools/io.h" +#include "tools/util/cli_consumer.h" void print_usage(char* argv0) { printf( @@ -44,14 +45,18 @@ Options: --max-function-args --max-control-flow-nesting-depth --max-access-chain-indexes - --relax-logcial-pointer Allow allocating an object of a pointer type and returning + --relax-logical-pointer Allow allocating an object of a pointer type and returning a pointer value from a function in logical addressing mode + --relax-block-layout Enable VK_HR_relaxed_block_layout when checking standard + uniform/storage buffer layout + --skip-block-layout Skip checking standard uniform/storage buffer layout --relax-struct-store Allow store from one struct type to a different type with compatible layout and members. --version Display validator version information. - --target-env {vulkan1.0|spv1.0|spv1.1|spv1.2} - Use Vulkan1.0/SPIR-V1.0/SPIR-V1.1/SPIR-V1.2 validation rules. + --target-env {vulkan1.0|vulkan1.1|opencl2.2|spv1.0|spv1.1|spv1.2|spv1.3|webgpu0} + Use Vulkan 1.0, Vulkan 1.1, OpenCL 2.2, SPIR-V 1.0, + SPIR-V 1.1, SPIR-V 1.2, SPIR-V 1.3 or WIP WebGPU validation rules. )", argv0, argv0); } @@ -90,14 +95,15 @@ int main(int argc, char** argv) { } } else if (0 == strcmp(cur_arg, "--version")) { printf("%s\n", spvSoftwareVersionDetailsString()); - printf("Targets:\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n", + printf("Targets:\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n", spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_0), spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1), spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_2), spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_3), spvTargetEnvDescription(SPV_ENV_OPENCL_2_2), spvTargetEnvDescription(SPV_ENV_VULKAN_1_0), - spvTargetEnvDescription(SPV_ENV_VULKAN_1_1)); + spvTargetEnvDescription(SPV_ENV_VULKAN_1_1), + spvTargetEnvDescription(SPV_ENV_WEBGPU_0)); continue_processing = false; return_code = 0; } else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { @@ -119,6 +125,10 @@ int main(int argc, char** argv) { } } else if (0 == strcmp(cur_arg, "--relax-logical-pointer")) { options.SetRelaxLogicalPointer(true); + } else if (0 == strcmp(cur_arg, "--relax-block-layout")) { + options.SetRelaxBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--skip-block-layout")) { + options.SetSkipBlockLayout(true); } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { options.SetRelaxStructStore(true); } else if (0 == cur_arg[1]) { @@ -155,27 +165,7 @@ int main(int argc, char** argv) { if (!ReadFile(inFile, "rb", &contents)) return 1; spvtools::SpirvTools tools(target_env); - tools.SetMessageConsumer([](spv_message_level_t level, const char*, - const spv_position_t& position, - const char* message) { - switch (level) { - case SPV_MSG_FATAL: - case SPV_MSG_INTERNAL_ERROR: - case SPV_MSG_ERROR: - std::cerr << "error: " << position.index << ": " << message - << std::endl; - break; - case SPV_MSG_WARNING: - std::cout << "warning: " << position.index << ": " << message - << std::endl; - break; - case SPV_MSG_INFO: - std::cout << "info: " << position.index << ": " << message << std::endl; - break; - default: - break; - } - }); + tools.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); bool succeed = tools.Validate(contents.data(), contents.size(), options); diff --git a/3rdparty/spirv-tools/utils/check_code_format.sh b/3rdparty/spirv-tools/utils/check_code_format.sh old mode 100644 new mode 100755 diff --git a/3rdparty/spirv-tools/utils/check_copyright.py b/3rdparty/spirv-tools/utils/check_copyright.py old mode 100644 new mode 100755 diff --git a/3rdparty/spirv-tools/utils/check_symbol_exports.py b/3rdparty/spirv-tools/utils/check_symbol_exports.py old mode 100644 new mode 100755 index 624b33d95..c9c0364df --- a/3rdparty/spirv-tools/utils/check_symbol_exports.py +++ b/3rdparty/spirv-tools/utils/check_symbol_exports.py @@ -35,7 +35,8 @@ def command_output(cmd, directory): p = subprocess.Popen(cmd, cwd=directory, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + stderr=subprocess.PIPE, + universal_newlines=True) (stdout, _) = p.communicate() if p.returncode != 0: raise RuntimeError('Failed to run %s in %s' % (cmd, directory)) diff --git a/3rdparty/spirv-tools/utils/fixup_fuzz_result.py b/3rdparty/spirv-tools/utils/fixup_fuzz_result.py new file mode 100755 index 000000000..9fe54a3cc --- /dev/null +++ b/3rdparty/spirv-tools/utils/fixup_fuzz_result.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# Copyright (c) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +if len(sys.argv) < 1: + print("Need file to chop"); + +with open(sys.argv[1], mode='rb') as file: + file_content = file.read() + content = file_content[:len(file_content) - (len(file_content) % 4)] + sys.stdout.write(content) + diff --git a/3rdparty/spirv-tools/utils/generate_grammar_tables.py b/3rdparty/spirv-tools/utils/generate_grammar_tables.py old mode 100644 new mode 100755 index ae948d388..aabdad505 --- a/3rdparty/spirv-tools/utils/generate_grammar_tables.py +++ b/3rdparty/spirv-tools/utils/generate_grammar_tables.py @@ -109,7 +109,7 @@ def compose_extension_list(exts): a string containing the braced list of extensions named by exts. """ return "{" + ", ".join( - ['libspirv::Extension::k{}'.format(e) for e in exts]) + "}" + ['spvtools::Extension::k{}'.format(e) for e in exts]) + "}" def get_extension_array_name(extensions): @@ -133,7 +133,7 @@ def generate_extension_arrays(extensions): """ extensions = sorted(set([tuple(e) for e in extensions if e])) arrays = [ - 'static const libspirv::Extension {}[] = {};'.format( + 'static const spvtools::Extension {}[] = {};'.format( get_extension_array_name(e), compose_extension_list(e)) for e in extensions] return '\n'.join(arrays) diff --git a/3rdparty/spirv-tools/utils/generate_language_headers.py b/3rdparty/spirv-tools/utils/generate_language_headers.py old mode 100644 new mode 100755 diff --git a/3rdparty/spirv-tools/utils/generate_registry_tables.py b/3rdparty/spirv-tools/utils/generate_registry_tables.py old mode 100644 new mode 100755 diff --git a/3rdparty/spirv-tools/utils/generate_vim_syntax.py b/3rdparty/spirv-tools/utils/generate_vim_syntax.py old mode 100644 new mode 100755 diff --git a/3rdparty/spirv-tools/utils/update_build_version.py b/3rdparty/spirv-tools/utils/update_build_version.py old mode 100644 new mode 100755 diff --git a/scripts/shaderc.lua b/scripts/shaderc.lua index 3a973f90c..b26a1f687 100644 --- a/scripts/shaderc.lua +++ b/scripts/shaderc.lua @@ -17,6 +17,7 @@ project "spirv-opt" path.join(SPIRV_TOOLS, "include"), path.join(SPIRV_TOOLS, "include/generated"), path.join(SPIRV_TOOLS, "source"), + path.join(SPIRV_TOOLS), path.join(SPIRV_TOOLS, "external/SPIRV-Headers/include"), } @@ -25,95 +26,102 @@ project "spirv-opt" path.join(SPIRV_TOOLS, "source/opt/**.h"), -- libspirv - path.join(SPIRV_TOOLS, "source/util/bitutils.h"), - path.join(SPIRV_TOOLS, "source/util/bit_stream.h"), - path.join(SPIRV_TOOLS, "source/util/hex_float.h"), - path.join(SPIRV_TOOLS, "source/util/parse_number.h"), - path.join(SPIRV_TOOLS, "source/util/string_utils.h"), - path.join(SPIRV_TOOLS, "source/util/timer.h"), + path.join(SPIRV_TOOLS, "source/assembly_grammar.cpp"), path.join(SPIRV_TOOLS, "source/assembly_grammar.h"), + path.join(SPIRV_TOOLS, "source/binary.cpp"), path.join(SPIRV_TOOLS, "source/binary.h"), path.join(SPIRV_TOOLS, "source/cfa.h"), + path.join(SPIRV_TOOLS, "source/diagnostic.cpp"), path.join(SPIRV_TOOLS, "source/diagnostic.h"), + path.join(SPIRV_TOOLS, "source/disassemble.cpp"), path.join(SPIRV_TOOLS, "source/disassemble.h"), path.join(SPIRV_TOOLS, "source/enum_set.h"), + path.join(SPIRV_TOOLS, "source/enum_string_mapping.cpp"), path.join(SPIRV_TOOLS, "source/enum_string_mapping.h"), + path.join(SPIRV_TOOLS, "source/ext_inst.cpp"), path.join(SPIRV_TOOLS, "source/ext_inst.h"), + path.join(SPIRV_TOOLS, "source/extensions.cpp"), path.join(SPIRV_TOOLS, "source/extensions.h"), + path.join(SPIRV_TOOLS, "source/id_descriptor.cpp"), path.join(SPIRV_TOOLS, "source/id_descriptor.h"), path.join(SPIRV_TOOLS, "source/instruction.h"), path.join(SPIRV_TOOLS, "source/latest_version_glsl_std_450_header.h"), path.join(SPIRV_TOOLS, "source/latest_version_opencl_std_header.h"), path.join(SPIRV_TOOLS, "source/latest_version_spirv_header.h"), + path.join(SPIRV_TOOLS, "source/libspirv.cpp"), path.join(SPIRV_TOOLS, "source/macro.h"), + path.join(SPIRV_TOOLS, "source/name_mapper.cpp"), path.join(SPIRV_TOOLS, "source/name_mapper.h"), + path.join(SPIRV_TOOLS, "source/opcode.cpp"), path.join(SPIRV_TOOLS, "source/opcode.h"), + path.join(SPIRV_TOOLS, "source/operand.cpp"), path.join(SPIRV_TOOLS, "source/operand.h"), + path.join(SPIRV_TOOLS, "source/parsed_operand.cpp"), path.join(SPIRV_TOOLS, "source/parsed_operand.h"), + path.join(SPIRV_TOOLS, "source/print.cpp"), path.join(SPIRV_TOOLS, "source/print.h"), + path.join(SPIRV_TOOLS, "source/software_version.cpp"), path.join(SPIRV_TOOLS, "source/spirv_constant.h"), path.join(SPIRV_TOOLS, "source/spirv_definition.h"), - path.join(SPIRV_TOOLS, "source/spirv_endian.h"), - path.join(SPIRV_TOOLS, "source/spirv_target_env.h"), - path.join(SPIRV_TOOLS, "source/spirv_validator_options.h"), - path.join(SPIRV_TOOLS, "source/table.h"), - path.join(SPIRV_TOOLS, "source/text.h"), - path.join(SPIRV_TOOLS, "source/text_handler.h"), - path.join(SPIRV_TOOLS, "source/validate.h"), - path.join(SPIRV_TOOLS, "source/util/bit_stream.cpp"), - path.join(SPIRV_TOOLS, "source/util/parse_number.cpp"), - path.join(SPIRV_TOOLS, "source/util/string_utils.cpp"), - path.join(SPIRV_TOOLS, "source/assembly_grammar.cpp"), - path.join(SPIRV_TOOLS, "source/binary.cpp"), - path.join(SPIRV_TOOLS, "source/diagnostic.cpp"), - path.join(SPIRV_TOOLS, "source/disassemble.cpp"), - path.join(SPIRV_TOOLS, "source/enum_string_mapping.cpp"), - path.join(SPIRV_TOOLS, "source/ext_inst.cpp"), - path.join(SPIRV_TOOLS, "source/extensions.cpp"), - path.join(SPIRV_TOOLS, "source/id_descriptor.cpp"), - path.join(SPIRV_TOOLS, "source/libspirv.cpp"), - path.join(SPIRV_TOOLS, "source/message.cpp"), - path.join(SPIRV_TOOLS, "source/name_mapper.cpp"), - path.join(SPIRV_TOOLS, "source/opcode.cpp"), - path.join(SPIRV_TOOLS, "source/operand.cpp"), - path.join(SPIRV_TOOLS, "source/parsed_operand.cpp"), - path.join(SPIRV_TOOLS, "source/print.cpp"), - path.join(SPIRV_TOOLS, "source/software_version.cpp"), path.join(SPIRV_TOOLS, "source/spirv_endian.cpp"), - path.join(SPIRV_TOOLS, "source/spirv_stats.cpp"), + path.join(SPIRV_TOOLS, "source/spirv_endian.h"), path.join(SPIRV_TOOLS, "source/spirv_target_env.cpp"), + path.join(SPIRV_TOOLS, "source/spirv_target_env.h"), path.join(SPIRV_TOOLS, "source/spirv_validator_options.cpp"), + path.join(SPIRV_TOOLS, "source/spirv_validator_options.h"), path.join(SPIRV_TOOLS, "source/table.cpp"), + path.join(SPIRV_TOOLS, "source/table.h"), path.join(SPIRV_TOOLS, "source/text.cpp"), + path.join(SPIRV_TOOLS, "source/text.h"), path.join(SPIRV_TOOLS, "source/text_handler.cpp"), - path.join(SPIRV_TOOLS, "source/validate.cpp"), - path.join(SPIRV_TOOLS, "source/validate_adjacency.cpp"), - path.join(SPIRV_TOOLS, "source/validate_arithmetics.cpp"), - path.join(SPIRV_TOOLS, "source/validate_atomics.cpp"), - path.join(SPIRV_TOOLS, "source/validate_barriers.cpp"), - path.join(SPIRV_TOOLS, "source/validate_bitwise.cpp"), - path.join(SPIRV_TOOLS, "source/validate_builtins.cpp"), - path.join(SPIRV_TOOLS, "source/validate_capability.cpp"), - path.join(SPIRV_TOOLS, "source/validate_cfg.cpp"), - path.join(SPIRV_TOOLS, "source/validate_composites.cpp"), - path.join(SPIRV_TOOLS, "source/validate_conversion.cpp"), - path.join(SPIRV_TOOLS, "source/validate_datarules.cpp"), - path.join(SPIRV_TOOLS, "source/validate_decorations.cpp"), - path.join(SPIRV_TOOLS, "source/validate_derivatives.cpp"), - path.join(SPIRV_TOOLS, "source/validate_ext_inst.cpp"), - path.join(SPIRV_TOOLS, "source/validate_id.cpp"), - path.join(SPIRV_TOOLS, "source/validate_image.cpp"), - path.join(SPIRV_TOOLS, "source/validate_instruction.cpp"), - path.join(SPIRV_TOOLS, "source/validate_layout.cpp"), - path.join(SPIRV_TOOLS, "source/validate_literals.cpp"), - path.join(SPIRV_TOOLS, "source/validate_logicals.cpp"), - path.join(SPIRV_TOOLS, "source/validate_primitives.cpp"), - path.join(SPIRV_TOOLS, "source/validate_type_unique.cpp"), - path.join(SPIRV_TOOLS, "source/val/decoration.h"), + path.join(SPIRV_TOOLS, "source/text_handler.h"), + path.join(SPIRV_TOOLS, "source/util/bit_vector.cpp"), + path.join(SPIRV_TOOLS, "source/util/bit_vector.h"), + path.join(SPIRV_TOOLS, "source/util/bitutils.h"), + path.join(SPIRV_TOOLS, "source/util/hex_float.h"), + path.join(SPIRV_TOOLS, "source/util/parse_number.cpp"), + path.join(SPIRV_TOOLS, "source/util/parse_number.h"), + path.join(SPIRV_TOOLS, "source/util/string_utils.cpp"), + path.join(SPIRV_TOOLS, "source/util/string_utils.h"), + path.join(SPIRV_TOOLS, "source/util/timer.h"), path.join(SPIRV_TOOLS, "source/val/basic_block.cpp"), path.join(SPIRV_TOOLS, "source/val/construct.cpp"), + path.join(SPIRV_TOOLS, "source/val/decoration.h"), path.join(SPIRV_TOOLS, "source/val/function.cpp"), path.join(SPIRV_TOOLS, "source/val/instruction.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_adjacency.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_annotation.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_arithmetics.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_atomics.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_barriers.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_bitwise.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_builtins.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_capability.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_cfg.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_composites.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_constants.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_conversion.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_datarules.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_debug.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_decorations.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_derivatives.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_execution_limitations.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_ext_inst.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_function.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_id.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_image.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_interfaces.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_instruction.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_layout.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_literals.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_logicals.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_memory.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_mode_setting.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_non_uniform.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_primitives.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate_type.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate.cpp"), + path.join(SPIRV_TOOLS, "source/val/validate.h"), path.join(SPIRV_TOOLS, "source/val/validation_state.cpp"), }